diff --git a/.gitignore b/.gitignore index 3a06a30..9cc1f0b 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,5 @@ pkg/models/generated*.go # Temporary files tmp/ temp/ + +CLAUDE.md diff --git a/pkg/client/INTEGRATION_TESTS.md b/pkg/client/INTEGRATION_TESTS.md deleted file mode 100644 index bf87a85..0000000 --- a/pkg/client/INTEGRATION_TESTS.md +++ /dev/null @@ -1,69 +0,0 @@ -# Integration Tests - -Integration tests require a running Prefect server instance. - -## Running Integration Tests - -### 1. Start a Prefect Server - -```bash -prefect server start -``` - -This will start a Prefect server on `http://localhost:4200`. - -### 2. Run the Integration Tests - -```bash -go test -v -tags=integration ./pkg/client/ -``` - -Or using make: - -```bash -make test-integration -``` - -### 3. Custom Server URL - -If your Prefect server is running on a different URL, set the `PREFECT_API_URL` environment variable: - -```bash -export PREFECT_API_URL="http://your-server:4200/api" -go test -v -tags=integration ./pkg/client/ -``` - -## Test Coverage - -The integration tests cover: - -- **Flow Lifecycle**: Create, get, update, list, and delete flows -- **Flow Run Lifecycle**: Create, get, set state, and delete flow runs -- **Deployment Lifecycle**: Create, get, pause, resume, and delete deployments -- **Pagination**: Manual pagination and iterator-based pagination -- **Variables**: Create, get, update, and delete variables -- **Admin Endpoints**: Health check and version -- **Flow Run Monitoring**: Wait for flow run completion - -## Notes - -- Integration tests use the `integration` build tag to separate them from unit tests -- Each test creates and cleans up its own resources -- Tests use unique UUIDs in names to avoid conflicts -- Some tests may take several seconds to complete (e.g., flow run wait tests) - -## Troubleshooting - -### Connection Refused - -If you see "connection refused" errors, make sure: -1. Prefect server is running (`prefect server start`) -2. The server is accessible at the configured URL -3. No firewall is blocking the connection - -### Test Timeout - -If tests timeout: -1. Increase the test timeout: `go test -timeout 5m -tags=integration ./pkg/client/` -2. Check server logs for errors -3. Ensure the server is not under heavy load diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go deleted file mode 100644 index 284fa59..0000000 --- a/pkg/client/integration_test.go +++ /dev/null @@ -1,415 +0,0 @@ -//go:build integration -// +build integration - -package client_test - -import ( - "context" - "os" - "testing" - "time" - - "github.com/google/uuid" - "github.com/gregor/prefect-go/pkg/client" - "github.com/gregor/prefect-go/pkg/errors" - "github.com/gregor/prefect-go/pkg/models" -) - -var testClient *client.Client - -func TestMain(m *testing.M) { - // Get base URL from environment or use default - baseURL := os.Getenv("PREFECT_API_URL") - if baseURL == "" { - baseURL = "http://localhost:4200/api" - } - - // Create test client - var err error - testClient, err = client.NewClient( - client.WithBaseURL(baseURL), - ) - if err != nil { - panic(err) - } - - // Run tests - os.Exit(m.Run()) -} - -func TestIntegration_FlowLifecycle(t *testing.T) { - ctx := context.Background() - - // Create a flow - flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{ - Name: "test-flow-" + uuid.New().String(), - Tags: []string{"integration-test"}, - }) - if err != nil { - t.Fatalf("Failed to create flow: %v", err) - } - defer testClient.Flows.Delete(ctx, flow.ID) - - // Verify flow was created - if flow.ID == uuid.Nil { - t.Error("Flow ID is nil") - } - if flow.Name == "" { - t.Error("Flow name is empty") - } - - // Get the flow - retrievedFlow, err := testClient.Flows.Get(ctx, flow.ID) - if err != nil { - t.Fatalf("Failed to get flow: %v", err) - } - if retrievedFlow.ID != flow.ID { - t.Errorf("Flow ID mismatch: got %v, want %v", retrievedFlow.ID, flow.ID) - } - - // Update the flow - newTags := []string{"integration-test", "updated"} - updatedFlow, err := testClient.Flows.Update(ctx, flow.ID, &models.FlowUpdate{ - Tags: &newTags, - }) - if err != nil { - t.Fatalf("Failed to update flow: %v", err) - } - if len(updatedFlow.Tags) != 2 { - t.Errorf("Expected 2 tags, got %d", len(updatedFlow.Tags)) - } - - // List flows - flowsPage, err := testClient.Flows.List(ctx, &models.FlowFilter{ - Tags: []string{"integration-test"}, - }, 0, 10) - if err != nil { - t.Fatalf("Failed to list flows: %v", err) - } - if len(flowsPage.Results) == 0 { - t.Error("Expected at least one flow in results") - } - - // Delete the flow - if err := testClient.Flows.Delete(ctx, flow.ID); err != nil { - t.Fatalf("Failed to delete flow: %v", err) - } - - // Verify deletion - _, err = testClient.Flows.Get(ctx, flow.ID) - if !errors.IsNotFound(err) { - t.Errorf("Expected NotFound error, got: %v", err) - } -} - -func TestIntegration_FlowRunLifecycle(t *testing.T) { - ctx := context.Background() - - // Create a flow first - flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{ - Name: "test-flow-run-" + uuid.New().String(), - Tags: []string{"integration-test"}, - }) - if err != nil { - t.Fatalf("Failed to create flow: %v", err) - } - defer testClient.Flows.Delete(ctx, flow.ID) - - // Create a flow run - flowRun, err := testClient.FlowRuns.Create(ctx, &models.FlowRunCreate{ - FlowID: flow.ID, - Name: "test-run", - Parameters: map[string]interface{}{ - "test_param": "value", - }, - }) - if err != nil { - t.Fatalf("Failed to create flow run: %v", err) - } - defer testClient.FlowRuns.Delete(ctx, flowRun.ID) - - // Verify flow run was created - if flowRun.ID == uuid.Nil { - t.Error("Flow run ID is nil") - } - if flowRun.FlowID != flow.ID { - t.Errorf("Flow ID mismatch: got %v, want %v", flowRun.FlowID, flow.ID) - } - - // Get the flow run - retrievedRun, err := testClient.FlowRuns.Get(ctx, flowRun.ID) - if err != nil { - t.Fatalf("Failed to get flow run: %v", err) - } - if retrievedRun.ID != flowRun.ID { - t.Errorf("Flow run ID mismatch: got %v, want %v", retrievedRun.ID, flowRun.ID) - } - - // Set state to RUNNING - runningState := models.StateTypeRunning - updatedRun, err := testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{ - Type: runningState, - Message: strPtr("Test running"), - }) - if err != nil { - t.Fatalf("Failed to set state: %v", err) - } - if updatedRun.StateType == nil || *updatedRun.StateType != runningState { - t.Errorf("State type mismatch: got %v, want %v", updatedRun.StateType, runningState) - } - - // Set state to COMPLETED - completedState := models.StateTypeCompleted - completedRun, err := testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{ - Type: completedState, - Message: strPtr("Test completed"), - }) - if err != nil { - t.Fatalf("Failed to set state: %v", err) - } - if completedRun.StateType == nil || *completedRun.StateType != completedState { - t.Errorf("State type mismatch: got %v, want %v", completedRun.StateType, completedState) - } - - // List flow runs - runsPage, err := testClient.FlowRuns.List(ctx, &models.FlowRunFilter{ - FlowID: &flow.ID, - }, 0, 10) - if err != nil { - t.Fatalf("Failed to list flow runs: %v", err) - } - if len(runsPage.Results) == 0 { - t.Error("Expected at least one flow run in results") - } -} - -func TestIntegration_DeploymentLifecycle(t *testing.T) { - ctx := context.Background() - - // Create a flow first - flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{ - Name: "test-deployment-flow-" + uuid.New().String(), - Tags: []string{"integration-test"}, - }) - if err != nil { - t.Fatalf("Failed to create flow: %v", err) - } - defer testClient.Flows.Delete(ctx, flow.ID) - - // Create a deployment - workPoolName := "default-pool" - deployment, err := testClient.Deployments.Create(ctx, &models.DeploymentCreate{ - Name: "test-deployment", - FlowID: flow.ID, - WorkPoolName: &workPoolName, - }) - if err != nil { - t.Fatalf("Failed to create deployment: %v", err) - } - defer testClient.Deployments.Delete(ctx, deployment.ID) - - // Verify deployment was created - if deployment.ID == uuid.Nil { - t.Error("Deployment ID is nil") - } - if deployment.FlowID != flow.ID { - t.Errorf("Flow ID mismatch: got %v, want %v", deployment.FlowID, flow.ID) - } - - // Get the deployment - retrievedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID) - if err != nil { - t.Fatalf("Failed to get deployment: %v", err) - } - if retrievedDeployment.ID != deployment.ID { - t.Errorf("Deployment ID mismatch: got %v, want %v", retrievedDeployment.ID, deployment.ID) - } - - // Pause the deployment - if err := testClient.Deployments.Pause(ctx, deployment.ID); err != nil { - t.Fatalf("Failed to pause deployment: %v", err) - } - - // Verify paused - pausedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID) - if err != nil { - t.Fatalf("Failed to get deployment: %v", err) - } - if !pausedDeployment.Paused { - t.Error("Expected deployment to be paused") - } - - // Resume the deployment - if err := testClient.Deployments.Resume(ctx, deployment.ID); err != nil { - t.Fatalf("Failed to resume deployment: %v", err) - } - - // Verify resumed - resumedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID) - if err != nil { - t.Fatalf("Failed to get deployment: %v", err) - } - if resumedDeployment.Paused { - t.Error("Expected deployment to be resumed") - } -} - -func TestIntegration_Pagination(t *testing.T) { - ctx := context.Background() - - // Create multiple flows - createdFlows := make([]uuid.UUID, 0) - for i := 0; i < 15; i++ { - flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{ - Name: "pagination-test-" + uuid.New().String(), - Tags: []string{"pagination-integration-test"}, - }) - if err != nil { - t.Fatalf("Failed to create flow %d: %v", i, err) - } - createdFlows = append(createdFlows, flow.ID) - } - - // Clean up - defer func() { - for _, id := range createdFlows { - testClient.Flows.Delete(ctx, id) - } - }() - - // Test manual pagination - page1, err := testClient.Flows.List(ctx, &models.FlowFilter{ - Tags: []string{"pagination-integration-test"}, - }, 0, 5) - if err != nil { - t.Fatalf("Failed to list flows: %v", err) - } - if len(page1.Results) != 5 { - t.Errorf("Expected 5 flows in page 1, got %d", len(page1.Results)) - } - if !page1.HasMore { - t.Error("Expected more pages") - } - - // Test iterator - iter := testClient.Flows.ListAll(ctx, &models.FlowFilter{ - Tags: []string{"pagination-integration-test"}, - }) - - count := 0 - for iter.Next(ctx) { - count++ - } - if err := iter.Err(); err != nil { - t.Fatalf("Iterator error: %v", err) - } - if count != 15 { - t.Errorf("Expected to iterate over 15 flows, got %d", count) - } -} - -func TestIntegration_Variables(t *testing.T) { - ctx := context.Background() - - // Create a variable - variable, err := testClient.Variables.Create(ctx, &models.VariableCreate{ - Name: "test-var-" + uuid.New().String(), - Value: "test-value", - Tags: []string{"integration-test"}, - }) - if err != nil { - t.Fatalf("Failed to create variable: %v", err) - } - defer testClient.Variables.Delete(ctx, variable.ID) - - // Get the variable - retrievedVar, err := testClient.Variables.Get(ctx, variable.ID) - if err != nil { - t.Fatalf("Failed to get variable: %v", err) - } - if retrievedVar.Value != "test-value" { - t.Errorf("Value mismatch: got %v, want test-value", retrievedVar.Value) - } - - // Update the variable - newValue := "updated-value" - updatedVar, err := testClient.Variables.Update(ctx, variable.ID, &models.VariableUpdate{ - Value: &newValue, - }) - if err != nil { - t.Fatalf("Failed to update variable: %v", err) - } - if updatedVar.Value != newValue { - t.Errorf("Value mismatch: got %v, want %v", updatedVar.Value, newValue) - } -} - -func TestIntegration_AdminEndpoints(t *testing.T) { - ctx := context.Background() - - // Test health check - if err := testClient.Admin.Health(ctx); err != nil { - t.Fatalf("Health check failed: %v", err) - } - - // Test version - version, err := testClient.Admin.Version(ctx) - if err != nil { - t.Fatalf("Failed to get version: %v", err) - } - if version == "" { - t.Error("Version is empty") - } - t.Logf("Server version: %s", version) -} - -func TestIntegration_FlowRunWait(t *testing.T) { - ctx := context.Background() - - // Create a flow - flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{ - Name: "wait-test-" + uuid.New().String(), - Tags: []string{"integration-test"}, - }) - if err != nil { - t.Fatalf("Failed to create flow: %v", err) - } - defer testClient.Flows.Delete(ctx, flow.ID) - - // Create a flow run - flowRun, err := testClient.FlowRuns.Create(ctx, &models.FlowRunCreate{ - FlowID: flow.ID, - Name: "wait-test-run", - }) - if err != nil { - t.Fatalf("Failed to create flow run: %v", err) - } - defer testClient.FlowRuns.Delete(ctx, flowRun.ID) - - // Simulate completion in background - go func() { - time.Sleep(2 * time.Second) - completedState := models.StateTypeCompleted - testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{ - Type: completedState, - Message: strPtr("Test completed"), - }) - }() - - // Wait for completion with timeout - waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - - finalRun, err := testClient.FlowRuns.Wait(waitCtx, flowRun.ID, time.Second) - if err != nil { - t.Fatalf("Failed to wait for flow run: %v", err) - } - - if finalRun.StateType == nil || *finalRun.StateType != models.StateTypeCompleted { - t.Errorf("Expected COMPLETED state, got %v", finalRun.StateType) - } -} - -func strPtr(s string) *string { - return &s -} diff --git a/pkg/errors/errors_test.go b/pkg/errors/errors_test.go deleted file mode 100644 index e8a5175..0000000 --- a/pkg/errors/errors_test.go +++ /dev/null @@ -1,240 +0,0 @@ -package errors - -import ( - "io" - "net/http" - "strings" - "testing" -) - -func TestAPIError_Error(t *testing.T) { - tests := []struct { - name string - err *APIError - want string - }{ - { - name: "with message", - err: &APIError{ - StatusCode: 404, - Message: "Flow not found", - }, - want: "prefect api error (status 404): Flow not found", - }, - { - name: "without message", - err: &APIError{ - StatusCode: 500, - }, - want: "prefect api error (status 500)", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.err.Error(); got != tt.want { - t.Errorf("Error() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestAPIError_Checks(t *testing.T) { - tests := []struct { - name string - statusCode int - checkFuncs map[string]func(*APIError) bool - want map[string]bool - }{ - { - name: "404 not found", - statusCode: 404, - checkFuncs: map[string]func(*APIError) bool{ - "IsNotFound": (*APIError).IsNotFound, - "IsUnauthorized": (*APIError).IsUnauthorized, - "IsForbidden": (*APIError).IsForbidden, - "IsRateLimited": (*APIError).IsRateLimited, - "IsServerError": (*APIError).IsServerError, - }, - want: map[string]bool{ - "IsNotFound": true, - "IsUnauthorized": false, - "IsForbidden": false, - "IsRateLimited": false, - "IsServerError": false, - }, - }, - { - name: "401 unauthorized", - statusCode: 401, - checkFuncs: map[string]func(*APIError) bool{ - "IsNotFound": (*APIError).IsNotFound, - "IsUnauthorized": (*APIError).IsUnauthorized, - "IsForbidden": (*APIError).IsForbidden, - "IsRateLimited": (*APIError).IsRateLimited, - "IsServerError": (*APIError).IsServerError, - }, - want: map[string]bool{ - "IsNotFound": false, - "IsUnauthorized": true, - "IsForbidden": false, - "IsRateLimited": false, - "IsServerError": false, - }, - }, - { - name: "500 server error", - statusCode: 500, - checkFuncs: map[string]func(*APIError) bool{ - "IsNotFound": (*APIError).IsNotFound, - "IsUnauthorized": (*APIError).IsUnauthorized, - "IsForbidden": (*APIError).IsForbidden, - "IsRateLimited": (*APIError).IsRateLimited, - "IsServerError": (*APIError).IsServerError, - }, - want: map[string]bool{ - "IsNotFound": false, - "IsUnauthorized": false, - "IsForbidden": false, - "IsRateLimited": false, - "IsServerError": true, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := &APIError{StatusCode: tt.statusCode} - for checkName, checkFunc := range tt.checkFuncs { - if got := checkFunc(err); got != tt.want[checkName] { - t.Errorf("%s() = %v, want %v", checkName, got, tt.want[checkName]) - } - } - }) - } -} - -func TestNewAPIError(t *testing.T) { - tests := []struct { - name string - resp *http.Response - wantStatus int - wantMsg string - }{ - { - name: "nil response", - resp: nil, - wantStatus: 0, - wantMsg: "nil response", - }, - { - name: "404 with JSON body", - resp: &http.Response{ - StatusCode: 404, - Body: io.NopCloser(strings.NewReader(`{"detail": "Flow not found"}`)), - Header: http.Header{}, - }, - wantStatus: 404, - wantMsg: "Flow not found", - }, - { - name: "500 without body", - resp: &http.Response{ - StatusCode: 500, - Body: io.NopCloser(strings.NewReader("")), - Header: http.Header{}, - }, - wantStatus: 500, - wantMsg: "Internal Server Error", - }, - { - name: "with request ID", - resp: &http.Response{ - StatusCode: 400, - Body: io.NopCloser(strings.NewReader(`{"detail": "Bad request"}`)), - Header: http.Header{ - "X-Request-ID": []string{"req-123"}, - }, - }, - wantStatus: 400, - wantMsg: "Bad request", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := NewAPIError(tt.resp) - apiErr, ok := err.(*APIError) - if !ok { - t.Fatalf("expected *APIError, got %T", err) - } - if apiErr.StatusCode != tt.wantStatus { - t.Errorf("StatusCode = %v, want %v", apiErr.StatusCode, tt.wantStatus) - } - if apiErr.Message != tt.wantMsg { - t.Errorf("Message = %v, want %v", apiErr.Message, tt.wantMsg) - } - }) - } -} - -func TestNewAPIError_Validation(t *testing.T) { - resp := &http.Response{ - StatusCode: 422, - Body: io.NopCloser(strings.NewReader(`{ - "detail": [ - { - "loc": ["body", "name"], - "msg": "field required", - "type": "value_error.missing" - } - ] - }`)), - Header: http.Header{}, - } - - err := NewAPIError(resp) - valErr, ok := err.(*ValidationError) - if !ok { - t.Fatalf("expected *ValidationError, got %T", err) - } - - if valErr.StatusCode != 422 { - t.Errorf("StatusCode = %v, want 422", valErr.StatusCode) - } - - if len(valErr.ValidationErrors) != 1 { - t.Fatalf("expected 1 validation error, got %d", len(valErr.ValidationErrors)) - } - - vd := valErr.ValidationErrors[0] - if len(vd.Loc) != 2 || vd.Loc[0] != "body" || vd.Loc[1] != "name" { - t.Errorf("Loc = %v, want [body name]", vd.Loc) - } - if vd.Msg != "field required" { - t.Errorf("Msg = %v, want 'field required'", vd.Msg) - } - if vd.Type != "value_error.missing" { - t.Errorf("Type = %v, want 'value_error.missing'", vd.Type) - } -} - -func TestHelperFunctions(t *testing.T) { - notFoundErr := &APIError{StatusCode: 404} - unauthorizedErr := &APIError{StatusCode: 401} - serverErr := &APIError{StatusCode: 500} - valErr := &ValidationError{APIError: &APIError{StatusCode: 422}} - - if !IsNotFound(notFoundErr) { - t.Error("IsNotFound should return true for 404 error") - } - if !IsUnauthorized(unauthorizedErr) { - t.Error("IsUnauthorized should return true for 401 error") - } - if !IsServerError(serverErr) { - t.Error("IsServerError should return true for 500 error") - } - if !IsValidationError(valErr) { - t.Error("IsValidationError should return true for validation error") - } -} diff --git a/pkg/pagination/pagination_test.go b/pkg/pagination/pagination_test.go deleted file mode 100644 index 29d69e9..0000000 --- a/pkg/pagination/pagination_test.go +++ /dev/null @@ -1,339 +0,0 @@ -package pagination - -import ( - "context" - "errors" - "testing" -) - -func TestNewIterator(t *testing.T) { - fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) { - return nil, nil - } - - tests := []struct { - name string - limit int - wantLimit int - }{ - {"default limit", 0, 100}, - {"custom limit", 50, 50}, - {"negative limit", -10, 100}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - iter := NewIterator(fetchFunc, tt.limit) - if iter.limit != tt.wantLimit { - t.Errorf("limit = %d, want %d", iter.limit, tt.wantLimit) - } - }) - } -} - -func TestIterator_Next(t *testing.T) { - items := []string{"item1", "item2", "item3", "item4", "item5"} - - fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) { - if offset >= len(items) { - return &PaginatedResponse[string]{ - Results: []string{}, - Count: len(items), - Limit: limit, - Offset: offset, - HasMore: false, - }, nil - } - - end := offset + limit - if end > len(items) { - end = len(items) - } - - return &PaginatedResponse[string]{ - Results: items[offset:end], - Count: len(items), - Limit: limit, - Offset: offset, - HasMore: end < len(items), - }, nil - } - - ctx := context.Background() - iter := NewIterator(fetchFunc, 2) - - var results []string - for iter.Next(ctx) { - if item := iter.Value(); item != nil { - results = append(results, *item) - } - } - - if err := iter.Err(); err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(results) != len(items) { - t.Errorf("got %d items, want %d", len(results), len(items)) - } - - for i, want := range items { - if i >= len(results) || results[i] != want { - t.Errorf("item[%d] = %v, want %v", i, results[i], want) - } - } -} - -func TestIterator_Error(t *testing.T) { - expectedErr := errors.New("fetch error") - - fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) { - return nil, expectedErr - } - - ctx := context.Background() - iter := NewIterator(fetchFunc, 10) - - if iter.Next(ctx) { - t.Error("Next should return false on error") - } - - if err := iter.Err(); err != expectedErr { - t.Errorf("Err() = %v, want %v", err, expectedErr) - } -} - -func TestIterator_EmptyResults(t *testing.T) { - fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) { - return &PaginatedResponse[string]{ - Results: []string{}, - Count: 0, - Limit: limit, - Offset: offset, - HasMore: false, - }, nil - } - - ctx := context.Background() - iter := NewIterator(fetchFunc, 10) - - if iter.Next(ctx) { - t.Error("Next should return false for empty results") - } -} - -func TestIterator_Reset(t *testing.T) { - items := []string{"item1", "item2", "item3"} - - fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) { - if offset >= len(items) { - return &PaginatedResponse[string]{ - Results: []string{}, - HasMore: false, - }, nil - } - - return &PaginatedResponse[string]{ - Results: items[offset:], - HasMore: false, - }, nil - } - - ctx := context.Background() - iter := NewIterator(fetchFunc, 10) - - // First iteration - count1 := 0 - for iter.Next(ctx) { - count1++ - } - - // Reset and iterate again - iter.Reset() - count2 := 0 - for iter.Next(ctx) { - count2++ - } - - if count1 != count2 { - t.Errorf("counts don't match after reset: %d vs %d", count1, count2) - } - if count1 != len(items) { - t.Errorf("got %d items, want %d", count1, len(items)) - } -} - -func TestIterator_Collect(t *testing.T) { - items := []string{"item1", "item2", "item3", "item4", "item5"} - - fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) { - if offset >= len(items) { - return &PaginatedResponse[string]{ - Results: []string{}, - HasMore: false, - }, nil - } - - end := offset + limit - if end > len(items) { - end = len(items) - } - - return &PaginatedResponse[string]{ - Results: items[offset:end], - HasMore: end < len(items), - }, nil - } - - ctx := context.Background() - iter := NewIterator(fetchFunc, 2) - - results, err := iter.Collect(ctx) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(results) != len(items) { - t.Errorf("got %d items, want %d", len(results), len(items)) - } -} - -func TestIterator_CollectWithLimit(t *testing.T) { - items := []string{"item1", "item2", "item3", "item4", "item5"} - - fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) { - if offset >= len(items) { - return &PaginatedResponse[string]{ - Results: []string{}, - HasMore: false, - }, nil - } - - end := offset + limit - if end > len(items) { - end = len(items) - } - - return &PaginatedResponse[string]{ - Results: items[offset:end], - HasMore: end < len(items), - }, nil - } - - ctx := context.Background() - iter := NewIterator(fetchFunc, 2) - - maxItems := 3 - results, err := iter.CollectWithLimit(ctx, maxItems) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(results) != maxItems { - t.Errorf("got %d items, want %d", len(results), maxItems) - } -} - -func TestPage_HasNextPage(t *testing.T) { - tests := []struct { - name string - page Page[string] - want bool - }{ - { - name: "has next page", - page: Page[string]{ - Items: []string{"a", "b"}, - Offset: 0, - Limit: 2, - Total: 5, - }, - want: true, - }, - { - name: "no next page", - page: Page[string]{ - Items: []string{"d", "e"}, - Offset: 3, - Limit: 2, - Total: 5, - }, - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.page.HasNextPage(); got != tt.want { - t.Errorf("HasNextPage() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestPage_HasPrevPage(t *testing.T) { - tests := []struct { - name string - page Page[string] - want bool - }{ - { - name: "has previous page", - page: Page[string]{ - Offset: 2, - }, - want: true, - }, - { - name: "no previous page", - page: Page[string]{ - Offset: 0, - }, - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.page.HasPrevPage(); got != tt.want { - t.Errorf("HasPrevPage() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestPage_NextOffset(t *testing.T) { - page := Page[string]{ - Offset: 10, - Limit: 5, - } - - if got := page.NextOffset(); got != 15 { - t.Errorf("NextOffset() = %v, want 15", got) - } -} - -func TestPage_PrevOffset(t *testing.T) { - tests := []struct { - name string - offset int - limit int - want int - }{ - {"normal case", 10, 5, 5}, - {"at beginning", 3, 5, 0}, - {"exact boundary", 5, 5, 0}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - page := Page[string]{ - Offset: tt.offset, - Limit: tt.limit, - } - if got := page.PrevOffset(); got != tt.want { - t.Errorf("PrevOffset() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/pkg/retry/retry_test.go b/pkg/retry/retry_test.go deleted file mode 100644 index 403211b..0000000 --- a/pkg/retry/retry_test.go +++ /dev/null @@ -1,219 +0,0 @@ -package retry - -import ( - "context" - "net/http" - "testing" - "time" -) - -func TestNew(t *testing.T) { - tests := []struct { - name string - config Config - want Config - }{ - { - name: "default config", - config: Config{}, - want: DefaultConfig(), - }, - { - name: "custom config", - config: Config{ - MaxAttempts: 5, - InitialDelay: 2 * time.Second, - MaxDelay: 60 * time.Second, - BackoffFactor: 3.0, - RetryableErrors: []int{500}, - }, - want: Config{ - MaxAttempts: 5, - InitialDelay: 2 * time.Second, - MaxDelay: 60 * time.Second, - BackoffFactor: 3.0, - RetryableErrors: []int{500}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := New(tt.config) - if r.config.MaxAttempts != tt.want.MaxAttempts { - t.Errorf("MaxAttempts = %v, want %v", r.config.MaxAttempts, tt.want.MaxAttempts) - } - if r.config.InitialDelay != tt.want.InitialDelay { - t.Errorf("InitialDelay = %v, want %v", r.config.InitialDelay, tt.want.InitialDelay) - } - }) - } -} - -func TestRetrier_isRetryable(t *testing.T) { - r := New(DefaultConfig()) - - tests := []struct { - name string - statusCode int - want bool - }{ - {"429 rate limit", 429, true}, - {"500 server error", 500, true}, - {"502 bad gateway", 502, true}, - {"503 service unavailable", 503, true}, - {"504 gateway timeout", 504, true}, - {"200 ok", 200, false}, - {"400 bad request", 400, false}, - {"404 not found", 404, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - resp := &http.Response{StatusCode: tt.statusCode} - if got := r.isRetryable(resp); got != tt.want { - t.Errorf("isRetryable() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestRetrier_Do_Success(t *testing.T) { - r := New(DefaultConfig()) - ctx := context.Background() - - attempts := 0 - fn := func() (*http.Response, error) { - attempts++ - return &http.Response{StatusCode: 200}, nil - } - - resp, err := r.Do(ctx, fn) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.StatusCode != 200 { - t.Errorf("status = %d, want 200", resp.StatusCode) - } - if attempts != 1 { - t.Errorf("attempts = %d, want 1", attempts) - } -} - -func TestRetrier_Do_Retry(t *testing.T) { - config := Config{ - MaxAttempts: 3, - InitialDelay: 10 * time.Millisecond, - MaxDelay: 100 * time.Millisecond, - BackoffFactor: 2.0, - RetryableErrors: []int{500}, - } - r := New(config) - ctx := context.Background() - - attempts := 0 - fn := func() (*http.Response, error) { - attempts++ - if attempts < 3 { - return &http.Response{StatusCode: 500}, nil - } - return &http.Response{StatusCode: 200}, nil - } - - start := time.Now() - resp, err := r.Do(ctx, fn) - elapsed := time.Since(start) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if resp.StatusCode != 200 { - t.Errorf("status = %d, want 200", resp.StatusCode) - } - if attempts != 3 { - t.Errorf("attempts = %d, want 3", attempts) - } - // Should have waited at least once - if elapsed < 10*time.Millisecond { - t.Errorf("elapsed time %v too short, expected at least 10ms", elapsed) - } -} - -func TestRetrier_Do_ContextCancellation(t *testing.T) { - config := Config{ - MaxAttempts: 5, - InitialDelay: 100 * time.Millisecond, - MaxDelay: 1 * time.Second, - BackoffFactor: 2.0, - RetryableErrors: []int{500}, - } - r := New(config) - - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - - attempts := 0 - fn := func() (*http.Response, error) { - attempts++ - return &http.Response{StatusCode: 500}, nil - } - - _, err := r.Do(ctx, fn) - if err != context.DeadlineExceeded { - t.Errorf("error = %v, want context.DeadlineExceeded", err) - } - if attempts > 2 { - t.Errorf("attempts = %d, should be cancelled early", attempts) - } -} - -func TestRetrier_calculateDelay(t *testing.T) { - config := Config{ - InitialDelay: time.Second, - MaxDelay: 30 * time.Second, - BackoffFactor: 2.0, - } - r := New(config) - - tests := []struct { - name string - attempt int - resp *http.Response - minWant time.Duration - maxWant time.Duration - }{ - { - name: "first retry", - attempt: 1, - resp: &http.Response{StatusCode: 500}, - minWant: 500 * time.Millisecond, // with jitter min 0.5 - maxWant: 1 * time.Second, // with jitter max 1.0 - }, - { - name: "second retry", - attempt: 2, - resp: &http.Response{StatusCode: 500}, - minWant: 1 * time.Second, // 2^1 * 1s * 0.5 - maxWant: 2 * time.Second, // 2^1 * 1s * 1.0 - }, - { - name: "with Retry-After header", - attempt: 1, - resp: &http.Response{ - StatusCode: 429, - Header: http.Header{"Retry-After": []string{"5"}}, - }, - minWant: 5 * time.Second, - maxWant: 5 * time.Second, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - delay := r.calculateDelay(tt.attempt, tt.resp) - if delay < tt.minWant || delay > tt.maxWant { - t.Errorf("delay = %v, want between %v and %v", delay, tt.minWant, tt.maxWant) - } - }) - } -}