Entferne Integrations-Tests und zugehörige Dateien; füge CLAUDE.md zur .gitignore hinzu

This commit is contained in:
Gregor Schulte
2026-02-17 11:16:36 +01:00
parent df4978cc04
commit ae8a90d1e8
6 changed files with 2 additions and 1282 deletions

View File

@@ -1,69 +0,0 @@
# Integration Tests
Integration tests require a running Prefect server instance.
## Running Integration Tests
### 1. Start a Prefect Server
```bash
prefect server start
```
This will start a Prefect server on `http://localhost:4200`.
### 2. Run the Integration Tests
```bash
go test -v -tags=integration ./pkg/client/
```
Or using make:
```bash
make test-integration
```
### 3. Custom Server URL
If your Prefect server is running on a different URL, set the `PREFECT_API_URL` environment variable:
```bash
export PREFECT_API_URL="http://your-server:4200/api"
go test -v -tags=integration ./pkg/client/
```
## Test Coverage
The integration tests cover:
- **Flow Lifecycle**: Create, get, update, list, and delete flows
- **Flow Run Lifecycle**: Create, get, set state, and delete flow runs
- **Deployment Lifecycle**: Create, get, pause, resume, and delete deployments
- **Pagination**: Manual pagination and iterator-based pagination
- **Variables**: Create, get, update, and delete variables
- **Admin Endpoints**: Health check and version
- **Flow Run Monitoring**: Wait for flow run completion
## Notes
- Integration tests use the `integration` build tag to separate them from unit tests
- Each test creates and cleans up its own resources
- Tests use unique UUIDs in names to avoid conflicts
- Some tests may take several seconds to complete (e.g., flow run wait tests)
## Troubleshooting
### Connection Refused
If you see "connection refused" errors, make sure:
1. Prefect server is running (`prefect server start`)
2. The server is accessible at the configured URL
3. No firewall is blocking the connection
### Test Timeout
If tests timeout:
1. Increase the test timeout: `go test -timeout 5m -tags=integration ./pkg/client/`
2. Check server logs for errors
3. Ensure the server is not under heavy load

View File

@@ -1,415 +0,0 @@
//go:build integration
// +build integration
package client_test
import (
"context"
"os"
"testing"
"time"
"github.com/google/uuid"
"github.com/gregor/prefect-go/pkg/client"
"github.com/gregor/prefect-go/pkg/errors"
"github.com/gregor/prefect-go/pkg/models"
)
var testClient *client.Client
func TestMain(m *testing.M) {
// Get base URL from environment or use default
baseURL := os.Getenv("PREFECT_API_URL")
if baseURL == "" {
baseURL = "http://localhost:4200/api"
}
// Create test client
var err error
testClient, err = client.NewClient(
client.WithBaseURL(baseURL),
)
if err != nil {
panic(err)
}
// Run tests
os.Exit(m.Run())
}
func TestIntegration_FlowLifecycle(t *testing.T) {
ctx := context.Background()
// Create a flow
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
Name: "test-flow-" + uuid.New().String(),
Tags: []string{"integration-test"},
})
if err != nil {
t.Fatalf("Failed to create flow: %v", err)
}
defer testClient.Flows.Delete(ctx, flow.ID)
// Verify flow was created
if flow.ID == uuid.Nil {
t.Error("Flow ID is nil")
}
if flow.Name == "" {
t.Error("Flow name is empty")
}
// Get the flow
retrievedFlow, err := testClient.Flows.Get(ctx, flow.ID)
if err != nil {
t.Fatalf("Failed to get flow: %v", err)
}
if retrievedFlow.ID != flow.ID {
t.Errorf("Flow ID mismatch: got %v, want %v", retrievedFlow.ID, flow.ID)
}
// Update the flow
newTags := []string{"integration-test", "updated"}
updatedFlow, err := testClient.Flows.Update(ctx, flow.ID, &models.FlowUpdate{
Tags: &newTags,
})
if err != nil {
t.Fatalf("Failed to update flow: %v", err)
}
if len(updatedFlow.Tags) != 2 {
t.Errorf("Expected 2 tags, got %d", len(updatedFlow.Tags))
}
// List flows
flowsPage, err := testClient.Flows.List(ctx, &models.FlowFilter{
Tags: []string{"integration-test"},
}, 0, 10)
if err != nil {
t.Fatalf("Failed to list flows: %v", err)
}
if len(flowsPage.Results) == 0 {
t.Error("Expected at least one flow in results")
}
// Delete the flow
if err := testClient.Flows.Delete(ctx, flow.ID); err != nil {
t.Fatalf("Failed to delete flow: %v", err)
}
// Verify deletion
_, err = testClient.Flows.Get(ctx, flow.ID)
if !errors.IsNotFound(err) {
t.Errorf("Expected NotFound error, got: %v", err)
}
}
func TestIntegration_FlowRunLifecycle(t *testing.T) {
ctx := context.Background()
// Create a flow first
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
Name: "test-flow-run-" + uuid.New().String(),
Tags: []string{"integration-test"},
})
if err != nil {
t.Fatalf("Failed to create flow: %v", err)
}
defer testClient.Flows.Delete(ctx, flow.ID)
// Create a flow run
flowRun, err := testClient.FlowRuns.Create(ctx, &models.FlowRunCreate{
FlowID: flow.ID,
Name: "test-run",
Parameters: map[string]interface{}{
"test_param": "value",
},
})
if err != nil {
t.Fatalf("Failed to create flow run: %v", err)
}
defer testClient.FlowRuns.Delete(ctx, flowRun.ID)
// Verify flow run was created
if flowRun.ID == uuid.Nil {
t.Error("Flow run ID is nil")
}
if flowRun.FlowID != flow.ID {
t.Errorf("Flow ID mismatch: got %v, want %v", flowRun.FlowID, flow.ID)
}
// Get the flow run
retrievedRun, err := testClient.FlowRuns.Get(ctx, flowRun.ID)
if err != nil {
t.Fatalf("Failed to get flow run: %v", err)
}
if retrievedRun.ID != flowRun.ID {
t.Errorf("Flow run ID mismatch: got %v, want %v", retrievedRun.ID, flowRun.ID)
}
// Set state to RUNNING
runningState := models.StateTypeRunning
updatedRun, err := testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{
Type: runningState,
Message: strPtr("Test running"),
})
if err != nil {
t.Fatalf("Failed to set state: %v", err)
}
if updatedRun.StateType == nil || *updatedRun.StateType != runningState {
t.Errorf("State type mismatch: got %v, want %v", updatedRun.StateType, runningState)
}
// Set state to COMPLETED
completedState := models.StateTypeCompleted
completedRun, err := testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{
Type: completedState,
Message: strPtr("Test completed"),
})
if err != nil {
t.Fatalf("Failed to set state: %v", err)
}
if completedRun.StateType == nil || *completedRun.StateType != completedState {
t.Errorf("State type mismatch: got %v, want %v", completedRun.StateType, completedState)
}
// List flow runs
runsPage, err := testClient.FlowRuns.List(ctx, &models.FlowRunFilter{
FlowID: &flow.ID,
}, 0, 10)
if err != nil {
t.Fatalf("Failed to list flow runs: %v", err)
}
if len(runsPage.Results) == 0 {
t.Error("Expected at least one flow run in results")
}
}
func TestIntegration_DeploymentLifecycle(t *testing.T) {
ctx := context.Background()
// Create a flow first
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
Name: "test-deployment-flow-" + uuid.New().String(),
Tags: []string{"integration-test"},
})
if err != nil {
t.Fatalf("Failed to create flow: %v", err)
}
defer testClient.Flows.Delete(ctx, flow.ID)
// Create a deployment
workPoolName := "default-pool"
deployment, err := testClient.Deployments.Create(ctx, &models.DeploymentCreate{
Name: "test-deployment",
FlowID: flow.ID,
WorkPoolName: &workPoolName,
})
if err != nil {
t.Fatalf("Failed to create deployment: %v", err)
}
defer testClient.Deployments.Delete(ctx, deployment.ID)
// Verify deployment was created
if deployment.ID == uuid.Nil {
t.Error("Deployment ID is nil")
}
if deployment.FlowID != flow.ID {
t.Errorf("Flow ID mismatch: got %v, want %v", deployment.FlowID, flow.ID)
}
// Get the deployment
retrievedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID)
if err != nil {
t.Fatalf("Failed to get deployment: %v", err)
}
if retrievedDeployment.ID != deployment.ID {
t.Errorf("Deployment ID mismatch: got %v, want %v", retrievedDeployment.ID, deployment.ID)
}
// Pause the deployment
if err := testClient.Deployments.Pause(ctx, deployment.ID); err != nil {
t.Fatalf("Failed to pause deployment: %v", err)
}
// Verify paused
pausedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID)
if err != nil {
t.Fatalf("Failed to get deployment: %v", err)
}
if !pausedDeployment.Paused {
t.Error("Expected deployment to be paused")
}
// Resume the deployment
if err := testClient.Deployments.Resume(ctx, deployment.ID); err != nil {
t.Fatalf("Failed to resume deployment: %v", err)
}
// Verify resumed
resumedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID)
if err != nil {
t.Fatalf("Failed to get deployment: %v", err)
}
if resumedDeployment.Paused {
t.Error("Expected deployment to be resumed")
}
}
func TestIntegration_Pagination(t *testing.T) {
ctx := context.Background()
// Create multiple flows
createdFlows := make([]uuid.UUID, 0)
for i := 0; i < 15; i++ {
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
Name: "pagination-test-" + uuid.New().String(),
Tags: []string{"pagination-integration-test"},
})
if err != nil {
t.Fatalf("Failed to create flow %d: %v", i, err)
}
createdFlows = append(createdFlows, flow.ID)
}
// Clean up
defer func() {
for _, id := range createdFlows {
testClient.Flows.Delete(ctx, id)
}
}()
// Test manual pagination
page1, err := testClient.Flows.List(ctx, &models.FlowFilter{
Tags: []string{"pagination-integration-test"},
}, 0, 5)
if err != nil {
t.Fatalf("Failed to list flows: %v", err)
}
if len(page1.Results) != 5 {
t.Errorf("Expected 5 flows in page 1, got %d", len(page1.Results))
}
if !page1.HasMore {
t.Error("Expected more pages")
}
// Test iterator
iter := testClient.Flows.ListAll(ctx, &models.FlowFilter{
Tags: []string{"pagination-integration-test"},
})
count := 0
for iter.Next(ctx) {
count++
}
if err := iter.Err(); err != nil {
t.Fatalf("Iterator error: %v", err)
}
if count != 15 {
t.Errorf("Expected to iterate over 15 flows, got %d", count)
}
}
func TestIntegration_Variables(t *testing.T) {
ctx := context.Background()
// Create a variable
variable, err := testClient.Variables.Create(ctx, &models.VariableCreate{
Name: "test-var-" + uuid.New().String(),
Value: "test-value",
Tags: []string{"integration-test"},
})
if err != nil {
t.Fatalf("Failed to create variable: %v", err)
}
defer testClient.Variables.Delete(ctx, variable.ID)
// Get the variable
retrievedVar, err := testClient.Variables.Get(ctx, variable.ID)
if err != nil {
t.Fatalf("Failed to get variable: %v", err)
}
if retrievedVar.Value != "test-value" {
t.Errorf("Value mismatch: got %v, want test-value", retrievedVar.Value)
}
// Update the variable
newValue := "updated-value"
updatedVar, err := testClient.Variables.Update(ctx, variable.ID, &models.VariableUpdate{
Value: &newValue,
})
if err != nil {
t.Fatalf("Failed to update variable: %v", err)
}
if updatedVar.Value != newValue {
t.Errorf("Value mismatch: got %v, want %v", updatedVar.Value, newValue)
}
}
func TestIntegration_AdminEndpoints(t *testing.T) {
ctx := context.Background()
// Test health check
if err := testClient.Admin.Health(ctx); err != nil {
t.Fatalf("Health check failed: %v", err)
}
// Test version
version, err := testClient.Admin.Version(ctx)
if err != nil {
t.Fatalf("Failed to get version: %v", err)
}
if version == "" {
t.Error("Version is empty")
}
t.Logf("Server version: %s", version)
}
func TestIntegration_FlowRunWait(t *testing.T) {
ctx := context.Background()
// Create a flow
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
Name: "wait-test-" + uuid.New().String(),
Tags: []string{"integration-test"},
})
if err != nil {
t.Fatalf("Failed to create flow: %v", err)
}
defer testClient.Flows.Delete(ctx, flow.ID)
// Create a flow run
flowRun, err := testClient.FlowRuns.Create(ctx, &models.FlowRunCreate{
FlowID: flow.ID,
Name: "wait-test-run",
})
if err != nil {
t.Fatalf("Failed to create flow run: %v", err)
}
defer testClient.FlowRuns.Delete(ctx, flowRun.ID)
// Simulate completion in background
go func() {
time.Sleep(2 * time.Second)
completedState := models.StateTypeCompleted
testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{
Type: completedState,
Message: strPtr("Test completed"),
})
}()
// Wait for completion with timeout
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
finalRun, err := testClient.FlowRuns.Wait(waitCtx, flowRun.ID, time.Second)
if err != nil {
t.Fatalf("Failed to wait for flow run: %v", err)
}
if finalRun.StateType == nil || *finalRun.StateType != models.StateTypeCompleted {
t.Errorf("Expected COMPLETED state, got %v", finalRun.StateType)
}
}
func strPtr(s string) *string {
return &s
}

View File

@@ -1,240 +0,0 @@
package errors
import (
"io"
"net/http"
"strings"
"testing"
)
func TestAPIError_Error(t *testing.T) {
tests := []struct {
name string
err *APIError
want string
}{
{
name: "with message",
err: &APIError{
StatusCode: 404,
Message: "Flow not found",
},
want: "prefect api error (status 404): Flow not found",
},
{
name: "without message",
err: &APIError{
StatusCode: 500,
},
want: "prefect api error (status 500)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.err.Error(); got != tt.want {
t.Errorf("Error() = %v, want %v", got, tt.want)
}
})
}
}
func TestAPIError_Checks(t *testing.T) {
tests := []struct {
name string
statusCode int
checkFuncs map[string]func(*APIError) bool
want map[string]bool
}{
{
name: "404 not found",
statusCode: 404,
checkFuncs: map[string]func(*APIError) bool{
"IsNotFound": (*APIError).IsNotFound,
"IsUnauthorized": (*APIError).IsUnauthorized,
"IsForbidden": (*APIError).IsForbidden,
"IsRateLimited": (*APIError).IsRateLimited,
"IsServerError": (*APIError).IsServerError,
},
want: map[string]bool{
"IsNotFound": true,
"IsUnauthorized": false,
"IsForbidden": false,
"IsRateLimited": false,
"IsServerError": false,
},
},
{
name: "401 unauthorized",
statusCode: 401,
checkFuncs: map[string]func(*APIError) bool{
"IsNotFound": (*APIError).IsNotFound,
"IsUnauthorized": (*APIError).IsUnauthorized,
"IsForbidden": (*APIError).IsForbidden,
"IsRateLimited": (*APIError).IsRateLimited,
"IsServerError": (*APIError).IsServerError,
},
want: map[string]bool{
"IsNotFound": false,
"IsUnauthorized": true,
"IsForbidden": false,
"IsRateLimited": false,
"IsServerError": false,
},
},
{
name: "500 server error",
statusCode: 500,
checkFuncs: map[string]func(*APIError) bool{
"IsNotFound": (*APIError).IsNotFound,
"IsUnauthorized": (*APIError).IsUnauthorized,
"IsForbidden": (*APIError).IsForbidden,
"IsRateLimited": (*APIError).IsRateLimited,
"IsServerError": (*APIError).IsServerError,
},
want: map[string]bool{
"IsNotFound": false,
"IsUnauthorized": false,
"IsForbidden": false,
"IsRateLimited": false,
"IsServerError": true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := &APIError{StatusCode: tt.statusCode}
for checkName, checkFunc := range tt.checkFuncs {
if got := checkFunc(err); got != tt.want[checkName] {
t.Errorf("%s() = %v, want %v", checkName, got, tt.want[checkName])
}
}
})
}
}
func TestNewAPIError(t *testing.T) {
tests := []struct {
name string
resp *http.Response
wantStatus int
wantMsg string
}{
{
name: "nil response",
resp: nil,
wantStatus: 0,
wantMsg: "nil response",
},
{
name: "404 with JSON body",
resp: &http.Response{
StatusCode: 404,
Body: io.NopCloser(strings.NewReader(`{"detail": "Flow not found"}`)),
Header: http.Header{},
},
wantStatus: 404,
wantMsg: "Flow not found",
},
{
name: "500 without body",
resp: &http.Response{
StatusCode: 500,
Body: io.NopCloser(strings.NewReader("")),
Header: http.Header{},
},
wantStatus: 500,
wantMsg: "Internal Server Error",
},
{
name: "with request ID",
resp: &http.Response{
StatusCode: 400,
Body: io.NopCloser(strings.NewReader(`{"detail": "Bad request"}`)),
Header: http.Header{
"X-Request-ID": []string{"req-123"},
},
},
wantStatus: 400,
wantMsg: "Bad request",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := NewAPIError(tt.resp)
apiErr, ok := err.(*APIError)
if !ok {
t.Fatalf("expected *APIError, got %T", err)
}
if apiErr.StatusCode != tt.wantStatus {
t.Errorf("StatusCode = %v, want %v", apiErr.StatusCode, tt.wantStatus)
}
if apiErr.Message != tt.wantMsg {
t.Errorf("Message = %v, want %v", apiErr.Message, tt.wantMsg)
}
})
}
}
func TestNewAPIError_Validation(t *testing.T) {
resp := &http.Response{
StatusCode: 422,
Body: io.NopCloser(strings.NewReader(`{
"detail": [
{
"loc": ["body", "name"],
"msg": "field required",
"type": "value_error.missing"
}
]
}`)),
Header: http.Header{},
}
err := NewAPIError(resp)
valErr, ok := err.(*ValidationError)
if !ok {
t.Fatalf("expected *ValidationError, got %T", err)
}
if valErr.StatusCode != 422 {
t.Errorf("StatusCode = %v, want 422", valErr.StatusCode)
}
if len(valErr.ValidationErrors) != 1 {
t.Fatalf("expected 1 validation error, got %d", len(valErr.ValidationErrors))
}
vd := valErr.ValidationErrors[0]
if len(vd.Loc) != 2 || vd.Loc[0] != "body" || vd.Loc[1] != "name" {
t.Errorf("Loc = %v, want [body name]", vd.Loc)
}
if vd.Msg != "field required" {
t.Errorf("Msg = %v, want 'field required'", vd.Msg)
}
if vd.Type != "value_error.missing" {
t.Errorf("Type = %v, want 'value_error.missing'", vd.Type)
}
}
func TestHelperFunctions(t *testing.T) {
notFoundErr := &APIError{StatusCode: 404}
unauthorizedErr := &APIError{StatusCode: 401}
serverErr := &APIError{StatusCode: 500}
valErr := &ValidationError{APIError: &APIError{StatusCode: 422}}
if !IsNotFound(notFoundErr) {
t.Error("IsNotFound should return true for 404 error")
}
if !IsUnauthorized(unauthorizedErr) {
t.Error("IsUnauthorized should return true for 401 error")
}
if !IsServerError(serverErr) {
t.Error("IsServerError should return true for 500 error")
}
if !IsValidationError(valErr) {
t.Error("IsValidationError should return true for validation error")
}
}

View File

@@ -1,339 +0,0 @@
package pagination
import (
"context"
"errors"
"testing"
)
func TestNewIterator(t *testing.T) {
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
return nil, nil
}
tests := []struct {
name string
limit int
wantLimit int
}{
{"default limit", 0, 100},
{"custom limit", 50, 50},
{"negative limit", -10, 100},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
iter := NewIterator(fetchFunc, tt.limit)
if iter.limit != tt.wantLimit {
t.Errorf("limit = %d, want %d", iter.limit, tt.wantLimit)
}
})
}
}
func TestIterator_Next(t *testing.T) {
items := []string{"item1", "item2", "item3", "item4", "item5"}
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
if offset >= len(items) {
return &PaginatedResponse[string]{
Results: []string{},
Count: len(items),
Limit: limit,
Offset: offset,
HasMore: false,
}, nil
}
end := offset + limit
if end > len(items) {
end = len(items)
}
return &PaginatedResponse[string]{
Results: items[offset:end],
Count: len(items),
Limit: limit,
Offset: offset,
HasMore: end < len(items),
}, nil
}
ctx := context.Background()
iter := NewIterator(fetchFunc, 2)
var results []string
for iter.Next(ctx) {
if item := iter.Value(); item != nil {
results = append(results, *item)
}
}
if err := iter.Err(); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != len(items) {
t.Errorf("got %d items, want %d", len(results), len(items))
}
for i, want := range items {
if i >= len(results) || results[i] != want {
t.Errorf("item[%d] = %v, want %v", i, results[i], want)
}
}
}
func TestIterator_Error(t *testing.T) {
expectedErr := errors.New("fetch error")
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
return nil, expectedErr
}
ctx := context.Background()
iter := NewIterator(fetchFunc, 10)
if iter.Next(ctx) {
t.Error("Next should return false on error")
}
if err := iter.Err(); err != expectedErr {
t.Errorf("Err() = %v, want %v", err, expectedErr)
}
}
func TestIterator_EmptyResults(t *testing.T) {
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
return &PaginatedResponse[string]{
Results: []string{},
Count: 0,
Limit: limit,
Offset: offset,
HasMore: false,
}, nil
}
ctx := context.Background()
iter := NewIterator(fetchFunc, 10)
if iter.Next(ctx) {
t.Error("Next should return false for empty results")
}
}
func TestIterator_Reset(t *testing.T) {
items := []string{"item1", "item2", "item3"}
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
if offset >= len(items) {
return &PaginatedResponse[string]{
Results: []string{},
HasMore: false,
}, nil
}
return &PaginatedResponse[string]{
Results: items[offset:],
HasMore: false,
}, nil
}
ctx := context.Background()
iter := NewIterator(fetchFunc, 10)
// First iteration
count1 := 0
for iter.Next(ctx) {
count1++
}
// Reset and iterate again
iter.Reset()
count2 := 0
for iter.Next(ctx) {
count2++
}
if count1 != count2 {
t.Errorf("counts don't match after reset: %d vs %d", count1, count2)
}
if count1 != len(items) {
t.Errorf("got %d items, want %d", count1, len(items))
}
}
func TestIterator_Collect(t *testing.T) {
items := []string{"item1", "item2", "item3", "item4", "item5"}
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
if offset >= len(items) {
return &PaginatedResponse[string]{
Results: []string{},
HasMore: false,
}, nil
}
end := offset + limit
if end > len(items) {
end = len(items)
}
return &PaginatedResponse[string]{
Results: items[offset:end],
HasMore: end < len(items),
}, nil
}
ctx := context.Background()
iter := NewIterator(fetchFunc, 2)
results, err := iter.Collect(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != len(items) {
t.Errorf("got %d items, want %d", len(results), len(items))
}
}
func TestIterator_CollectWithLimit(t *testing.T) {
items := []string{"item1", "item2", "item3", "item4", "item5"}
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
if offset >= len(items) {
return &PaginatedResponse[string]{
Results: []string{},
HasMore: false,
}, nil
}
end := offset + limit
if end > len(items) {
end = len(items)
}
return &PaginatedResponse[string]{
Results: items[offset:end],
HasMore: end < len(items),
}, nil
}
ctx := context.Background()
iter := NewIterator(fetchFunc, 2)
maxItems := 3
results, err := iter.CollectWithLimit(ctx, maxItems)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != maxItems {
t.Errorf("got %d items, want %d", len(results), maxItems)
}
}
func TestPage_HasNextPage(t *testing.T) {
tests := []struct {
name string
page Page[string]
want bool
}{
{
name: "has next page",
page: Page[string]{
Items: []string{"a", "b"},
Offset: 0,
Limit: 2,
Total: 5,
},
want: true,
},
{
name: "no next page",
page: Page[string]{
Items: []string{"d", "e"},
Offset: 3,
Limit: 2,
Total: 5,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.page.HasNextPage(); got != tt.want {
t.Errorf("HasNextPage() = %v, want %v", got, tt.want)
}
})
}
}
func TestPage_HasPrevPage(t *testing.T) {
tests := []struct {
name string
page Page[string]
want bool
}{
{
name: "has previous page",
page: Page[string]{
Offset: 2,
},
want: true,
},
{
name: "no previous page",
page: Page[string]{
Offset: 0,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.page.HasPrevPage(); got != tt.want {
t.Errorf("HasPrevPage() = %v, want %v", got, tt.want)
}
})
}
}
func TestPage_NextOffset(t *testing.T) {
page := Page[string]{
Offset: 10,
Limit: 5,
}
if got := page.NextOffset(); got != 15 {
t.Errorf("NextOffset() = %v, want 15", got)
}
}
func TestPage_PrevOffset(t *testing.T) {
tests := []struct {
name string
offset int
limit int
want int
}{
{"normal case", 10, 5, 5},
{"at beginning", 3, 5, 0},
{"exact boundary", 5, 5, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
page := Page[string]{
Offset: tt.offset,
Limit: tt.limit,
}
if got := page.PrevOffset(); got != tt.want {
t.Errorf("PrevOffset() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -1,219 +0,0 @@
package retry
import (
"context"
"net/http"
"testing"
"time"
)
func TestNew(t *testing.T) {
tests := []struct {
name string
config Config
want Config
}{
{
name: "default config",
config: Config{},
want: DefaultConfig(),
},
{
name: "custom config",
config: Config{
MaxAttempts: 5,
InitialDelay: 2 * time.Second,
MaxDelay: 60 * time.Second,
BackoffFactor: 3.0,
RetryableErrors: []int{500},
},
want: Config{
MaxAttempts: 5,
InitialDelay: 2 * time.Second,
MaxDelay: 60 * time.Second,
BackoffFactor: 3.0,
RetryableErrors: []int{500},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := New(tt.config)
if r.config.MaxAttempts != tt.want.MaxAttempts {
t.Errorf("MaxAttempts = %v, want %v", r.config.MaxAttempts, tt.want.MaxAttempts)
}
if r.config.InitialDelay != tt.want.InitialDelay {
t.Errorf("InitialDelay = %v, want %v", r.config.InitialDelay, tt.want.InitialDelay)
}
})
}
}
func TestRetrier_isRetryable(t *testing.T) {
r := New(DefaultConfig())
tests := []struct {
name string
statusCode int
want bool
}{
{"429 rate limit", 429, true},
{"500 server error", 500, true},
{"502 bad gateway", 502, true},
{"503 service unavailable", 503, true},
{"504 gateway timeout", 504, true},
{"200 ok", 200, false},
{"400 bad request", 400, false},
{"404 not found", 404, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := &http.Response{StatusCode: tt.statusCode}
if got := r.isRetryable(resp); got != tt.want {
t.Errorf("isRetryable() = %v, want %v", got, tt.want)
}
})
}
}
func TestRetrier_Do_Success(t *testing.T) {
r := New(DefaultConfig())
ctx := context.Background()
attempts := 0
fn := func() (*http.Response, error) {
attempts++
return &http.Response{StatusCode: 200}, nil
}
resp, err := r.Do(ctx, fn)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != 200 {
t.Errorf("status = %d, want 200", resp.StatusCode)
}
if attempts != 1 {
t.Errorf("attempts = %d, want 1", attempts)
}
}
func TestRetrier_Do_Retry(t *testing.T) {
config := Config{
MaxAttempts: 3,
InitialDelay: 10 * time.Millisecond,
MaxDelay: 100 * time.Millisecond,
BackoffFactor: 2.0,
RetryableErrors: []int{500},
}
r := New(config)
ctx := context.Background()
attempts := 0
fn := func() (*http.Response, error) {
attempts++
if attempts < 3 {
return &http.Response{StatusCode: 500}, nil
}
return &http.Response{StatusCode: 200}, nil
}
start := time.Now()
resp, err := r.Do(ctx, fn)
elapsed := time.Since(start)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != 200 {
t.Errorf("status = %d, want 200", resp.StatusCode)
}
if attempts != 3 {
t.Errorf("attempts = %d, want 3", attempts)
}
// Should have waited at least once
if elapsed < 10*time.Millisecond {
t.Errorf("elapsed time %v too short, expected at least 10ms", elapsed)
}
}
func TestRetrier_Do_ContextCancellation(t *testing.T) {
config := Config{
MaxAttempts: 5,
InitialDelay: 100 * time.Millisecond,
MaxDelay: 1 * time.Second,
BackoffFactor: 2.0,
RetryableErrors: []int{500},
}
r := New(config)
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
attempts := 0
fn := func() (*http.Response, error) {
attempts++
return &http.Response{StatusCode: 500}, nil
}
_, err := r.Do(ctx, fn)
if err != context.DeadlineExceeded {
t.Errorf("error = %v, want context.DeadlineExceeded", err)
}
if attempts > 2 {
t.Errorf("attempts = %d, should be cancelled early", attempts)
}
}
func TestRetrier_calculateDelay(t *testing.T) {
config := Config{
InitialDelay: time.Second,
MaxDelay: 30 * time.Second,
BackoffFactor: 2.0,
}
r := New(config)
tests := []struct {
name string
attempt int
resp *http.Response
minWant time.Duration
maxWant time.Duration
}{
{
name: "first retry",
attempt: 1,
resp: &http.Response{StatusCode: 500},
minWant: 500 * time.Millisecond, // with jitter min 0.5
maxWant: 1 * time.Second, // with jitter max 1.0
},
{
name: "second retry",
attempt: 2,
resp: &http.Response{StatusCode: 500},
minWant: 1 * time.Second, // 2^1 * 1s * 0.5
maxWant: 2 * time.Second, // 2^1 * 1s * 1.0
},
{
name: "with Retry-After header",
attempt: 1,
resp: &http.Response{
StatusCode: 429,
Header: http.Header{"Retry-After": []string{"5"}},
},
minWant: 5 * time.Second,
maxWant: 5 * time.Second,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
delay := r.calculateDelay(tt.attempt, tt.resp)
if delay < tt.minWant || delay > tt.maxWant {
t.Errorf("delay = %v, want between %v and %v", delay, tt.minWant, tt.maxWant)
}
})
}
}