Initial commit

This commit is contained in:
Gregor Schulte
2026-02-02 08:41:48 +01:00
commit 43b4910a63
30 changed files with 4898 additions and 0 deletions

View File

@@ -0,0 +1,69 @@
# Integration Tests
Integration tests require a running Prefect server instance.
## Running Integration Tests
### 1. Start a Prefect Server
```bash
prefect server start
```
This will start a Prefect server on `http://localhost:4200`.
### 2. Run the Integration Tests
```bash
go test -v -tags=integration ./pkg/client/
```
Or using make:
```bash
make test-integration
```
### 3. Custom Server URL
If your Prefect server is running on a different URL, set the `PREFECT_API_URL` environment variable:
```bash
export PREFECT_API_URL="http://your-server:4200/api"
go test -v -tags=integration ./pkg/client/
```
## Test Coverage
The integration tests cover:
- **Flow Lifecycle**: Create, get, update, list, and delete flows
- **Flow Run Lifecycle**: Create, get, set state, and delete flow runs
- **Deployment Lifecycle**: Create, get, pause, resume, and delete deployments
- **Pagination**: Manual pagination and iterator-based pagination
- **Variables**: Create, get, update, and delete variables
- **Admin Endpoints**: Health check and version
- **Flow Run Monitoring**: Wait for flow run completion
## Notes
- Integration tests use the `integration` build tag to separate them from unit tests
- Each test creates and cleans up its own resources
- Tests use unique UUIDs in names to avoid conflicts
- Some tests may take several seconds to complete (e.g., flow run wait tests)
## Troubleshooting
### Connection Refused
If you see "connection refused" errors, make sure:
1. Prefect server is running (`prefect server start`)
2. The server is accessible at the configured URL
3. No firewall is blocking the connection
### Test Timeout
If tests timeout:
1. Increase the test timeout: `go test -timeout 5m -tags=integration ./pkg/client/`
2. Check server logs for errors
3. Ensure the server is not under heavy load

268
pkg/client/client.go Normal file
View File

@@ -0,0 +1,268 @@
// Package client provides a Go client for the Prefect Server API.
package client
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/gregor/prefect-go/pkg/errors"
"github.com/gregor/prefect-go/pkg/retry"
)
const (
defaultBaseURL = "http://localhost:4200/api"
defaultTimeout = 30 * time.Second
userAgent = "prefect-go-client/1.0.0"
)
// Client is the main client for interacting with the Prefect API.
type Client struct {
baseURL *url.URL
httpClient *http.Client
apiKey string
retrier *retry.Retrier
userAgent string
// Services
Flows *FlowsService
FlowRuns *FlowRunsService
Deployments *DeploymentsService
TaskRuns *TaskRunsService
WorkPools *WorkPoolsService
WorkQueues *WorkQueuesService
Variables *VariablesService
Logs *LogsService
Admin *AdminService
}
// Option is a functional option for configuring the client.
type Option func(*Client) error
// NewClient creates a new Prefect API client with the given options.
func NewClient(opts ...Option) (*Client, error) {
// Parse default base URL
baseURL, err := url.Parse(defaultBaseURL)
if err != nil {
return nil, fmt.Errorf("failed to parse default base URL: %w", err)
}
c := &Client{
baseURL: baseURL,
httpClient: &http.Client{
Timeout: defaultTimeout,
},
retrier: retry.New(retry.DefaultConfig()),
userAgent: userAgent,
}
// Apply options
for _, opt := range opts {
if err := opt(c); err != nil {
return nil, fmt.Errorf("failed to apply option: %w", err)
}
}
// Initialize services
c.Flows = &FlowsService{client: c}
c.FlowRuns = &FlowRunsService{client: c}
c.Deployments = &DeploymentsService{client: c}
c.TaskRuns = &TaskRunsService{client: c}
c.WorkPools = &WorkPoolsService{client: c}
c.WorkQueues = &WorkQueuesService{client: c}
c.Variables = &VariablesService{client: c}
c.Logs = &LogsService{client: c}
c.Admin = &AdminService{client: c}
return c, nil
}
// WithBaseURL sets the base URL for the Prefect API.
func WithBaseURL(baseURL string) Option {
return func(c *Client) error {
u, err := url.Parse(baseURL)
if err != nil {
return fmt.Errorf("invalid base URL: %w", err)
}
c.baseURL = u
return nil
}
}
// WithAPIKey sets the API key for authentication with Prefect Cloud.
func WithAPIKey(apiKey string) Option {
return func(c *Client) error {
c.apiKey = apiKey
return nil
}
}
// WithHTTPClient sets a custom HTTP client.
func WithHTTPClient(httpClient *http.Client) Option {
return func(c *Client) error {
if httpClient == nil {
return fmt.Errorf("HTTP client cannot be nil")
}
c.httpClient = httpClient
return nil
}
}
// WithTimeout sets the HTTP client timeout.
func WithTimeout(timeout time.Duration) Option {
return func(c *Client) error {
c.httpClient.Timeout = timeout
return nil
}
}
// WithRetry sets the retry configuration.
func WithRetry(config retry.Config) Option {
return func(c *Client) error {
c.retrier = retry.New(config)
return nil
}
}
// WithUserAgent sets a custom user agent string.
func WithUserAgent(ua string) Option {
return func(c *Client) error {
c.userAgent = ua
return nil
}
}
// do executes an HTTP request with retry logic.
func (c *Client) do(ctx context.Context, method, path string, body, result interface{}) error {
// Build full URL
u, err := c.baseURL.Parse(path)
if err != nil {
return fmt.Errorf("failed to parse path: %w", err)
}
// Serialize request body if present
var reqBody io.Reader
if body != nil {
data, err := json.Marshal(body)
if err != nil {
return fmt.Errorf("failed to marshal request body: %w", err)
}
reqBody = bytes.NewReader(data)
}
// Create request
req, err := http.NewRequestWithContext(ctx, method, u.String(), reqBody)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", c.userAgent)
// Set API key if present
if c.apiKey != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
}
// Execute request with retry logic
var resp *http.Response
resp, err = c.retrier.Do(ctx, func() (*http.Response, error) {
// Clone the request for retry attempts
reqClone := req.Clone(ctx)
if reqBody != nil {
// Reset body for retries
if seeker, ok := reqBody.(io.Seeker); ok {
if _, err := seeker.Seek(0, 0); err != nil {
return nil, fmt.Errorf("failed to reset request body: %w", err)
}
}
reqClone.Body = io.NopCloser(reqBody)
}
return c.httpClient.Do(reqClone)
})
if err != nil {
return fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
// Check for error status codes
if resp.StatusCode >= 400 {
return errors.NewAPIError(resp)
}
// Deserialize response body if result is provided
if result != nil && resp.StatusCode != http.StatusNoContent {
if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
return fmt.Errorf("failed to decode response: %w", err)
}
}
return nil
}
// get performs a GET request.
func (c *Client) get(ctx context.Context, path string, result interface{}) error {
return c.do(ctx, http.MethodGet, path, nil, result)
}
// post performs a POST request.
func (c *Client) post(ctx context.Context, path string, body, result interface{}) error {
return c.do(ctx, http.MethodPost, path, body, result)
}
// patch performs a PATCH request.
func (c *Client) patch(ctx context.Context, path string, body, result interface{}) error {
return c.do(ctx, http.MethodPatch, path, body, result)
}
// delete performs a DELETE request.
func (c *Client) delete(ctx context.Context, path string) error {
return c.do(ctx, http.MethodDelete, path, nil, nil)
}
// buildPath builds a URL path with optional query parameters.
func buildPath(base string, params map[string]string) string {
if len(params) == 0 {
return base
}
query := url.Values{}
for k, v := range params {
if v != "" {
query.Set(k, v)
}
}
if queryStr := query.Encode(); queryStr != "" {
return base + "?" + queryStr
}
return base
}
// buildPathWithValues builds a URL path with url.Values query parameters.
func buildPathWithValues(base string, query url.Values) string {
if len(query) == 0 {
return base
}
if queryStr := query.Encode(); queryStr != "" {
return base + "?" + queryStr
}
return base
}
// joinPath joins URL path segments.
func joinPath(parts ...string) string {
return strings.Join(parts, "/")
}

252
pkg/client/client_test.go Normal file
View File

@@ -0,0 +1,252 @@
package client
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestNewClient(t *testing.T) {
tests := []struct {
name string
opts []Option
wantErr bool
}{
{
name: "default client",
opts: nil,
wantErr: false,
},
{
name: "with base URL",
opts: []Option{
WithBaseURL("http://example.com/api"),
},
wantErr: false,
},
{
name: "with invalid base URL",
opts: []Option{
WithBaseURL("://invalid"),
},
wantErr: true,
},
{
name: "with API key",
opts: []Option{
WithAPIKey("test-key"),
},
wantErr: false,
},
{
name: "with custom timeout",
opts: []Option{
WithTimeout(60 * time.Second),
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, err := NewClient(tt.opts...)
if (err != nil) != tt.wantErr {
t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && client == nil {
t.Error("NewClient() returned nil client")
}
})
}
}
func TestClient_do(t *testing.T) {
// Create a test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check headers
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("Content-Type = %v, want application/json", r.Header.Get("Content-Type"))
}
if r.Header.Get("Accept") != "application/json" {
t.Errorf("Accept = %v, want application/json", r.Header.Get("Accept"))
}
// Check path
if r.URL.Path == "/test" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"message": "success"}`))
} else if r.URL.Path == "/error" {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"detail": "bad request"}`))
} else {
w.WriteHeader(http.StatusNotFound)
}
}))
defer server.Close()
client, err := NewClient(WithBaseURL(server.URL))
if err != nil {
t.Fatalf("failed to create client: %v", err)
}
ctx := context.Background()
t.Run("successful request", func(t *testing.T) {
var result map[string]string
err := client.get(ctx, "/test", &result)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if result["message"] != "success" {
t.Errorf("message = %v, want success", result["message"])
}
})
t.Run("error response", func(t *testing.T) {
var result map[string]string
err := client.get(ctx, "/error", &result)
if err == nil {
t.Error("expected error, got nil")
}
})
}
func TestClient_WithAPIKey(t *testing.T) {
apiKey := "test-api-key"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
expected := "Bearer " + apiKey
if auth != expected {
t.Errorf("Authorization = %v, want %v", auth, expected)
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client, err := NewClient(
WithBaseURL(server.URL),
WithAPIKey(apiKey),
)
if err != nil {
t.Fatalf("failed to create client: %v", err)
}
ctx := context.Background()
err = client.get(ctx, "/test", nil)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
func TestClient_WithCustomHTTPClient(t *testing.T) {
customClient := &http.Client{
Timeout: 10 * time.Second,
}
client, err := NewClient(WithHTTPClient(customClient))
if err != nil {
t.Fatalf("failed to create client: %v", err)
}
if client.httpClient != customClient {
t.Error("custom HTTP client not set")
}
}
func TestClient_WithNilHTTPClient(t *testing.T) {
_, err := NewClient(WithHTTPClient(nil))
if err == nil {
t.Error("expected error with nil HTTP client")
}
}
func TestBuildPath(t *testing.T) {
tests := []struct {
name string
base string
params map[string]string
want string
}{
{
name: "no params",
base: "/flows",
params: nil,
want: "/flows",
},
{
name: "with params",
base: "/flows",
params: map[string]string{
"limit": "10",
"offset": "20",
},
want: "/flows?limit=10&offset=20",
},
{
name: "with empty param value",
base: "/flows",
params: map[string]string{
"limit": "10",
"offset": "",
},
want: "/flows?limit=10",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := buildPath(tt.base, tt.params)
// Note: URL encoding might change order, so we just check it contains the params
if !contains(got, tt.base) {
t.Errorf("buildPath() = %v, should contain %v", got, tt.base)
}
})
}
}
func TestJoinPath(t *testing.T) {
tests := []struct {
name string
parts []string
want string
}{
{
name: "single part",
parts: []string{"flows"},
want: "flows",
},
{
name: "multiple parts",
parts: []string{"flows", "123", "runs"},
want: "flows/123/runs",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := joinPath(tt.parts...)
if got != tt.want {
t.Errorf("joinPath() = %v, want %v", got, tt.want)
}
})
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || containsMiddle(s, substr)))
}
func containsMiddle(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

157
pkg/client/flow_runs.go Normal file
View File

@@ -0,0 +1,157 @@
package client
import (
"context"
"fmt"
"time"
"github.com/google/uuid"
"github.com/gregor/prefect-go/pkg/models"
"github.com/gregor/prefect-go/pkg/pagination"
)
// FlowRunsService handles operations related to flow runs.
type FlowRunsService struct {
client *Client
}
// Create creates a new flow run.
func (s *FlowRunsService) Create(ctx context.Context, req *models.FlowRunCreate) (*models.FlowRun, error) {
var flowRun models.FlowRun
if err := s.client.post(ctx, "/flow_runs/", req, &flowRun); err != nil {
return nil, fmt.Errorf("failed to create flow run: %w", err)
}
return &flowRun, nil
}
// Get retrieves a flow run by ID.
func (s *FlowRunsService) Get(ctx context.Context, id uuid.UUID) (*models.FlowRun, error) {
var flowRun models.FlowRun
path := joinPath("/flow_runs", id.String())
if err := s.client.get(ctx, path, &flowRun); err != nil {
return nil, fmt.Errorf("failed to get flow run: %w", err)
}
return &flowRun, nil
}
// List retrieves a list of flow runs with optional filtering.
func (s *FlowRunsService) List(ctx context.Context, filter *models.FlowRunFilter, offset, limit int) (*pagination.PaginatedResponse[models.FlowRun], error) {
if filter == nil {
filter = &models.FlowRunFilter{}
}
filter.Offset = offset
filter.Limit = limit
type response struct {
Results []models.FlowRun `json:"results"`
Count int `json:"count"`
}
var resp response
if err := s.client.post(ctx, "/flow_runs/filter", filter, &resp); err != nil {
return nil, fmt.Errorf("failed to list flow runs: %w", err)
}
return &pagination.PaginatedResponse[models.FlowRun]{
Results: resp.Results,
Count: resp.Count,
Limit: limit,
Offset: offset,
HasMore: offset+len(resp.Results) < resp.Count,
}, nil
}
// ListAll returns an iterator for all flow runs matching the filter.
func (s *FlowRunsService) ListAll(ctx context.Context, filter *models.FlowRunFilter) *pagination.Iterator[models.FlowRun] {
fetchFunc := func(ctx context.Context, offset, limit int) (*pagination.PaginatedResponse[models.FlowRun], error) {
return s.List(ctx, filter, offset, limit)
}
return pagination.NewIterator(fetchFunc, 100)
}
// Update updates a flow run.
func (s *FlowRunsService) Update(ctx context.Context, id uuid.UUID, req *models.FlowRunUpdate) (*models.FlowRun, error) {
var flowRun models.FlowRun
path := joinPath("/flow_runs", id.String())
if err := s.client.patch(ctx, path, req, &flowRun); err != nil {
return nil, fmt.Errorf("failed to update flow run: %w", err)
}
return &flowRun, nil
}
// Delete deletes a flow run by ID.
func (s *FlowRunsService) Delete(ctx context.Context, id uuid.UUID) error {
path := joinPath("/flow_runs", id.String())
if err := s.client.delete(ctx, path); err != nil {
return fmt.Errorf("failed to delete flow run: %w", err)
}
return nil
}
// SetState sets the state of a flow run.
func (s *FlowRunsService) SetState(ctx context.Context, id uuid.UUID, state *models.StateCreate) (*models.FlowRun, error) {
var flowRun models.FlowRun
path := joinPath("/flow_runs", id.String(), "set_state")
req := struct {
State *models.StateCreate `json:"state"`
}{
State: state,
}
if err := s.client.post(ctx, path, req, &flowRun); err != nil {
return nil, fmt.Errorf("failed to set flow run state: %w", err)
}
return &flowRun, nil
}
// Wait waits for a flow run to complete by polling its state.
// It returns the final flow run state or an error if the context is cancelled.
func (s *FlowRunsService) Wait(ctx context.Context, id uuid.UUID, pollInterval time.Duration) (*models.FlowRun, error) {
if pollInterval <= 0 {
pollInterval = 5 * time.Second
}
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
flowRun, err := s.Get(ctx, id)
if err != nil {
return nil, err
}
// Check if flow run is in a terminal state
if flowRun.StateType != nil {
switch *flowRun.StateType {
case models.StateTypeCompleted,
models.StateTypeFailed,
models.StateTypeCancelled,
models.StateTypeCrashed:
return flowRun, nil
}
}
}
}
}
// Count returns the number of flow runs matching the filter.
func (s *FlowRunsService) Count(ctx context.Context, filter *models.FlowRunFilter) (int, error) {
if filter == nil {
filter = &models.FlowRunFilter{}
}
var result struct {
Count int `json:"count"`
}
if err := s.client.post(ctx, "/flow_runs/count", filter, &result); err != nil {
return 0, fmt.Errorf("failed to count flow runs: %w", err)
}
return result.Count, nil
}

115
pkg/client/flows.go Normal file
View File

@@ -0,0 +1,115 @@
package client
import (
"context"
"fmt"
"github.com/google/uuid"
"github.com/gregor/prefect-go/pkg/models"
"github.com/gregor/prefect-go/pkg/pagination"
)
// FlowsService handles operations related to flows.
type FlowsService struct {
client *Client
}
// Create creates a new flow.
func (s *FlowsService) Create(ctx context.Context, req *models.FlowCreate) (*models.Flow, error) {
var flow models.Flow
if err := s.client.post(ctx, "/flows/", req, &flow); err != nil {
return nil, fmt.Errorf("failed to create flow: %w", err)
}
return &flow, nil
}
// Get retrieves a flow by ID.
func (s *FlowsService) Get(ctx context.Context, id uuid.UUID) (*models.Flow, error) {
var flow models.Flow
path := joinPath("/flows", id.String())
if err := s.client.get(ctx, path, &flow); err != nil {
return nil, fmt.Errorf("failed to get flow: %w", err)
}
return &flow, nil
}
// GetByName retrieves a flow by name.
func (s *FlowsService) GetByName(ctx context.Context, name string) (*models.Flow, error) {
var flow models.Flow
path := buildPath("/flows/name/"+name, nil)
if err := s.client.get(ctx, path, &flow); err != nil {
return nil, fmt.Errorf("failed to get flow by name: %w", err)
}
return &flow, nil
}
// List retrieves a list of flows with optional filtering.
func (s *FlowsService) List(ctx context.Context, filter *models.FlowFilter, offset, limit int) (*pagination.PaginatedResponse[models.Flow], error) {
if filter == nil {
filter = &models.FlowFilter{}
}
filter.Offset = offset
filter.Limit = limit
type response struct {
Results []models.Flow `json:"results"`
Count int `json:"count"`
}
var resp response
if err := s.client.post(ctx, "/flows/filter", filter, &resp); err != nil {
return nil, fmt.Errorf("failed to list flows: %w", err)
}
return &pagination.PaginatedResponse[models.Flow]{
Results: resp.Results,
Count: resp.Count,
Limit: limit,
Offset: offset,
HasMore: offset+len(resp.Results) < resp.Count,
}, nil
}
// ListAll returns an iterator for all flows matching the filter.
func (s *FlowsService) ListAll(ctx context.Context, filter *models.FlowFilter) *pagination.Iterator[models.Flow] {
fetchFunc := func(ctx context.Context, offset, limit int) (*pagination.PaginatedResponse[models.Flow], error) {
return s.List(ctx, filter, offset, limit)
}
return pagination.NewIterator(fetchFunc, 100)
}
// Update updates a flow.
func (s *FlowsService) Update(ctx context.Context, id uuid.UUID, req *models.FlowUpdate) (*models.Flow, error) {
var flow models.Flow
path := joinPath("/flows", id.String())
if err := s.client.patch(ctx, path, req, &flow); err != nil {
return nil, fmt.Errorf("failed to update flow: %w", err)
}
return &flow, nil
}
// Delete deletes a flow by ID.
func (s *FlowsService) Delete(ctx context.Context, id uuid.UUID) error {
path := joinPath("/flows", id.String())
if err := s.client.delete(ctx, path); err != nil {
return fmt.Errorf("failed to delete flow: %w", err)
}
return nil
}
// Count returns the number of flows matching the filter.
func (s *FlowsService) Count(ctx context.Context, filter *models.FlowFilter) (int, error) {
if filter == nil {
filter = &models.FlowFilter{}
}
var result struct {
Count int `json:"count"`
}
if err := s.client.post(ctx, "/flows/count", filter, &result); err != nil {
return 0, fmt.Errorf("failed to count flows: %w", err)
}
return result.Count, nil
}

View File

@@ -0,0 +1,415 @@
//go:build integration
// +build integration
package client_test
import (
"context"
"os"
"testing"
"time"
"github.com/google/uuid"
"github.com/gregor/prefect-go/pkg/client"
"github.com/gregor/prefect-go/pkg/errors"
"github.com/gregor/prefect-go/pkg/models"
)
var testClient *client.Client
func TestMain(m *testing.M) {
// Get base URL from environment or use default
baseURL := os.Getenv("PREFECT_API_URL")
if baseURL == "" {
baseURL = "http://localhost:4200/api"
}
// Create test client
var err error
testClient, err = client.NewClient(
client.WithBaseURL(baseURL),
)
if err != nil {
panic(err)
}
// Run tests
os.Exit(m.Run())
}
func TestIntegration_FlowLifecycle(t *testing.T) {
ctx := context.Background()
// Create a flow
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
Name: "test-flow-" + uuid.New().String(),
Tags: []string{"integration-test"},
})
if err != nil {
t.Fatalf("Failed to create flow: %v", err)
}
defer testClient.Flows.Delete(ctx, flow.ID)
// Verify flow was created
if flow.ID == uuid.Nil {
t.Error("Flow ID is nil")
}
if flow.Name == "" {
t.Error("Flow name is empty")
}
// Get the flow
retrievedFlow, err := testClient.Flows.Get(ctx, flow.ID)
if err != nil {
t.Fatalf("Failed to get flow: %v", err)
}
if retrievedFlow.ID != flow.ID {
t.Errorf("Flow ID mismatch: got %v, want %v", retrievedFlow.ID, flow.ID)
}
// Update the flow
newTags := []string{"integration-test", "updated"}
updatedFlow, err := testClient.Flows.Update(ctx, flow.ID, &models.FlowUpdate{
Tags: &newTags,
})
if err != nil {
t.Fatalf("Failed to update flow: %v", err)
}
if len(updatedFlow.Tags) != 2 {
t.Errorf("Expected 2 tags, got %d", len(updatedFlow.Tags))
}
// List flows
flowsPage, err := testClient.Flows.List(ctx, &models.FlowFilter{
Tags: []string{"integration-test"},
}, 0, 10)
if err != nil {
t.Fatalf("Failed to list flows: %v", err)
}
if len(flowsPage.Results) == 0 {
t.Error("Expected at least one flow in results")
}
// Delete the flow
if err := testClient.Flows.Delete(ctx, flow.ID); err != nil {
t.Fatalf("Failed to delete flow: %v", err)
}
// Verify deletion
_, err = testClient.Flows.Get(ctx, flow.ID)
if !errors.IsNotFound(err) {
t.Errorf("Expected NotFound error, got: %v", err)
}
}
func TestIntegration_FlowRunLifecycle(t *testing.T) {
ctx := context.Background()
// Create a flow first
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
Name: "test-flow-run-" + uuid.New().String(),
Tags: []string{"integration-test"},
})
if err != nil {
t.Fatalf("Failed to create flow: %v", err)
}
defer testClient.Flows.Delete(ctx, flow.ID)
// Create a flow run
flowRun, err := testClient.FlowRuns.Create(ctx, &models.FlowRunCreate{
FlowID: flow.ID,
Name: "test-run",
Parameters: map[string]interface{}{
"test_param": "value",
},
})
if err != nil {
t.Fatalf("Failed to create flow run: %v", err)
}
defer testClient.FlowRuns.Delete(ctx, flowRun.ID)
// Verify flow run was created
if flowRun.ID == uuid.Nil {
t.Error("Flow run ID is nil")
}
if flowRun.FlowID != flow.ID {
t.Errorf("Flow ID mismatch: got %v, want %v", flowRun.FlowID, flow.ID)
}
// Get the flow run
retrievedRun, err := testClient.FlowRuns.Get(ctx, flowRun.ID)
if err != nil {
t.Fatalf("Failed to get flow run: %v", err)
}
if retrievedRun.ID != flowRun.ID {
t.Errorf("Flow run ID mismatch: got %v, want %v", retrievedRun.ID, flowRun.ID)
}
// Set state to RUNNING
runningState := models.StateTypeRunning
updatedRun, err := testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{
Type: runningState,
Message: strPtr("Test running"),
})
if err != nil {
t.Fatalf("Failed to set state: %v", err)
}
if updatedRun.StateType == nil || *updatedRun.StateType != runningState {
t.Errorf("State type mismatch: got %v, want %v", updatedRun.StateType, runningState)
}
// Set state to COMPLETED
completedState := models.StateTypeCompleted
completedRun, err := testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{
Type: completedState,
Message: strPtr("Test completed"),
})
if err != nil {
t.Fatalf("Failed to set state: %v", err)
}
if completedRun.StateType == nil || *completedRun.StateType != completedState {
t.Errorf("State type mismatch: got %v, want %v", completedRun.StateType, completedState)
}
// List flow runs
runsPage, err := testClient.FlowRuns.List(ctx, &models.FlowRunFilter{
FlowID: &flow.ID,
}, 0, 10)
if err != nil {
t.Fatalf("Failed to list flow runs: %v", err)
}
if len(runsPage.Results) == 0 {
t.Error("Expected at least one flow run in results")
}
}
func TestIntegration_DeploymentLifecycle(t *testing.T) {
ctx := context.Background()
// Create a flow first
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
Name: "test-deployment-flow-" + uuid.New().String(),
Tags: []string{"integration-test"},
})
if err != nil {
t.Fatalf("Failed to create flow: %v", err)
}
defer testClient.Flows.Delete(ctx, flow.ID)
// Create a deployment
workPoolName := "default-pool"
deployment, err := testClient.Deployments.Create(ctx, &models.DeploymentCreate{
Name: "test-deployment",
FlowID: flow.ID,
WorkPoolName: &workPoolName,
})
if err != nil {
t.Fatalf("Failed to create deployment: %v", err)
}
defer testClient.Deployments.Delete(ctx, deployment.ID)
// Verify deployment was created
if deployment.ID == uuid.Nil {
t.Error("Deployment ID is nil")
}
if deployment.FlowID != flow.ID {
t.Errorf("Flow ID mismatch: got %v, want %v", deployment.FlowID, flow.ID)
}
// Get the deployment
retrievedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID)
if err != nil {
t.Fatalf("Failed to get deployment: %v", err)
}
if retrievedDeployment.ID != deployment.ID {
t.Errorf("Deployment ID mismatch: got %v, want %v", retrievedDeployment.ID, deployment.ID)
}
// Pause the deployment
if err := testClient.Deployments.Pause(ctx, deployment.ID); err != nil {
t.Fatalf("Failed to pause deployment: %v", err)
}
// Verify paused
pausedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID)
if err != nil {
t.Fatalf("Failed to get deployment: %v", err)
}
if !pausedDeployment.Paused {
t.Error("Expected deployment to be paused")
}
// Resume the deployment
if err := testClient.Deployments.Resume(ctx, deployment.ID); err != nil {
t.Fatalf("Failed to resume deployment: %v", err)
}
// Verify resumed
resumedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID)
if err != nil {
t.Fatalf("Failed to get deployment: %v", err)
}
if resumedDeployment.Paused {
t.Error("Expected deployment to be resumed")
}
}
func TestIntegration_Pagination(t *testing.T) {
ctx := context.Background()
// Create multiple flows
createdFlows := make([]uuid.UUID, 0)
for i := 0; i < 15; i++ {
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
Name: "pagination-test-" + uuid.New().String(),
Tags: []string{"pagination-integration-test"},
})
if err != nil {
t.Fatalf("Failed to create flow %d: %v", i, err)
}
createdFlows = append(createdFlows, flow.ID)
}
// Clean up
defer func() {
for _, id := range createdFlows {
testClient.Flows.Delete(ctx, id)
}
}()
// Test manual pagination
page1, err := testClient.Flows.List(ctx, &models.FlowFilter{
Tags: []string{"pagination-integration-test"},
}, 0, 5)
if err != nil {
t.Fatalf("Failed to list flows: %v", err)
}
if len(page1.Results) != 5 {
t.Errorf("Expected 5 flows in page 1, got %d", len(page1.Results))
}
if !page1.HasMore {
t.Error("Expected more pages")
}
// Test iterator
iter := testClient.Flows.ListAll(ctx, &models.FlowFilter{
Tags: []string{"pagination-integration-test"},
})
count := 0
for iter.Next(ctx) {
count++
}
if err := iter.Err(); err != nil {
t.Fatalf("Iterator error: %v", err)
}
if count != 15 {
t.Errorf("Expected to iterate over 15 flows, got %d", count)
}
}
func TestIntegration_Variables(t *testing.T) {
ctx := context.Background()
// Create a variable
variable, err := testClient.Variables.Create(ctx, &models.VariableCreate{
Name: "test-var-" + uuid.New().String(),
Value: "test-value",
Tags: []string{"integration-test"},
})
if err != nil {
t.Fatalf("Failed to create variable: %v", err)
}
defer testClient.Variables.Delete(ctx, variable.ID)
// Get the variable
retrievedVar, err := testClient.Variables.Get(ctx, variable.ID)
if err != nil {
t.Fatalf("Failed to get variable: %v", err)
}
if retrievedVar.Value != "test-value" {
t.Errorf("Value mismatch: got %v, want test-value", retrievedVar.Value)
}
// Update the variable
newValue := "updated-value"
updatedVar, err := testClient.Variables.Update(ctx, variable.ID, &models.VariableUpdate{
Value: &newValue,
})
if err != nil {
t.Fatalf("Failed to update variable: %v", err)
}
if updatedVar.Value != newValue {
t.Errorf("Value mismatch: got %v, want %v", updatedVar.Value, newValue)
}
}
func TestIntegration_AdminEndpoints(t *testing.T) {
ctx := context.Background()
// Test health check
if err := testClient.Admin.Health(ctx); err != nil {
t.Fatalf("Health check failed: %v", err)
}
// Test version
version, err := testClient.Admin.Version(ctx)
if err != nil {
t.Fatalf("Failed to get version: %v", err)
}
if version == "" {
t.Error("Version is empty")
}
t.Logf("Server version: %s", version)
}
func TestIntegration_FlowRunWait(t *testing.T) {
ctx := context.Background()
// Create a flow
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
Name: "wait-test-" + uuid.New().String(),
Tags: []string{"integration-test"},
})
if err != nil {
t.Fatalf("Failed to create flow: %v", err)
}
defer testClient.Flows.Delete(ctx, flow.ID)
// Create a flow run
flowRun, err := testClient.FlowRuns.Create(ctx, &models.FlowRunCreate{
FlowID: flow.ID,
Name: "wait-test-run",
})
if err != nil {
t.Fatalf("Failed to create flow run: %v", err)
}
defer testClient.FlowRuns.Delete(ctx, flowRun.ID)
// Simulate completion in background
go func() {
time.Sleep(2 * time.Second)
completedState := models.StateTypeCompleted
testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{
Type: completedState,
Message: strPtr("Test completed"),
})
}()
// Wait for completion with timeout
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
finalRun, err := testClient.FlowRuns.Wait(waitCtx, flowRun.ID, time.Second)
if err != nil {
t.Fatalf("Failed to wait for flow run: %v", err)
}
if finalRun.StateType == nil || *finalRun.StateType != models.StateTypeCompleted {
t.Errorf("Expected COMPLETED state, got %v", finalRun.StateType)
}
}
func strPtr(s string) *string {
return &s
}

301
pkg/client/services.go Normal file
View File

@@ -0,0 +1,301 @@
package client
import (
"context"
"fmt"
"github.com/google/uuid"
"github.com/gregor/prefect-go/pkg/models"
"github.com/gregor/prefect-go/pkg/pagination"
)
// DeploymentsService handles operations related to deployments.
type DeploymentsService struct {
client *Client
}
// Create creates a new deployment.
func (s *DeploymentsService) Create(ctx context.Context, req *models.DeploymentCreate) (*models.Deployment, error) {
var deployment models.Deployment
if err := s.client.post(ctx, "/deployments/", req, &deployment); err != nil {
return nil, fmt.Errorf("failed to create deployment: %w", err)
}
return &deployment, nil
}
// Get retrieves a deployment by ID.
func (s *DeploymentsService) Get(ctx context.Context, id uuid.UUID) (*models.Deployment, error) {
var deployment models.Deployment
path := joinPath("/deployments", id.String())
if err := s.client.get(ctx, path, &deployment); err != nil {
return nil, fmt.Errorf("failed to get deployment: %w", err)
}
return &deployment, nil
}
// GetByName retrieves a deployment by flow name and deployment name.
func (s *DeploymentsService) GetByName(ctx context.Context, flowName, deploymentName string) (*models.Deployment, error) {
var deployment models.Deployment
path := fmt.Sprintf("/deployments/name/%s/%s", flowName, deploymentName)
if err := s.client.get(ctx, path, &deployment); err != nil {
return nil, fmt.Errorf("failed to get deployment by name: %w", err)
}
return &deployment, nil
}
// Update updates a deployment.
func (s *DeploymentsService) Update(ctx context.Context, id uuid.UUID, req *models.DeploymentUpdate) (*models.Deployment, error) {
var deployment models.Deployment
path := joinPath("/deployments", id.String())
if err := s.client.patch(ctx, path, req, &deployment); err != nil {
return nil, fmt.Errorf("failed to update deployment: %w", err)
}
return &deployment, nil
}
// Delete deletes a deployment by ID.
func (s *DeploymentsService) Delete(ctx context.Context, id uuid.UUID) error {
path := joinPath("/deployments", id.String())
if err := s.client.delete(ctx, path); err != nil {
return fmt.Errorf("failed to delete deployment: %w", err)
}
return nil
}
// Pause pauses a deployment's schedule.
func (s *DeploymentsService) Pause(ctx context.Context, id uuid.UUID) error {
path := joinPath("/deployments", id.String(), "pause")
if err := s.client.post(ctx, path, nil, nil); err != nil {
return fmt.Errorf("failed to pause deployment: %w", err)
}
return nil
}
// Resume resumes a deployment's schedule.
func (s *DeploymentsService) Resume(ctx context.Context, id uuid.UUID) error {
path := joinPath("/deployments", id.String(), "resume")
if err := s.client.post(ctx, path, nil, nil); err != nil {
return fmt.Errorf("failed to resume deployment: %w", err)
}
return nil
}
// CreateFlowRun creates a flow run from a deployment.
func (s *DeploymentsService) CreateFlowRun(ctx context.Context, id uuid.UUID, params map[string]interface{}) (*models.FlowRun, error) {
var flowRun models.FlowRun
path := joinPath("/deployments", id.String(), "create_flow_run")
req := struct {
Parameters map[string]interface{} `json:"parameters,omitempty"`
}{
Parameters: params,
}
if err := s.client.post(ctx, path, req, &flowRun); err != nil {
return nil, fmt.Errorf("failed to create flow run from deployment: %w", err)
}
return &flowRun, nil
}
// TaskRunsService handles operations related to task runs.
type TaskRunsService struct {
client *Client
}
// Create creates a new task run.
func (t *TaskRunsService) Create(ctx context.Context, req *models.TaskRunCreate) (*models.TaskRun, error) {
var taskRun models.TaskRun
if err := t.client.post(ctx, "/task_runs/", req, &taskRun); err != nil {
return nil, fmt.Errorf("failed to create task run: %w", err)
}
return &taskRun, nil
}
// Get retrieves a task run by ID.
func (t *TaskRunsService) Get(ctx context.Context, id uuid.UUID) (*models.TaskRun, error) {
var taskRun models.TaskRun
path := joinPath("/task_runs", id.String())
if err := t.client.get(ctx, path, &taskRun); err != nil {
return nil, fmt.Errorf("failed to get task run: %w", err)
}
return &taskRun, nil
}
// Delete deletes a task run by ID.
func (t *TaskRunsService) Delete(ctx context.Context, id uuid.UUID) error {
path := joinPath("/task_runs", id.String())
if err := t.client.delete(ctx, path); err != nil {
return fmt.Errorf("failed to delete task run: %w", err)
}
return nil
}
// SetState sets the state of a task run.
func (t *TaskRunsService) SetState(ctx context.Context, id uuid.UUID, state *models.StateCreate) (*models.TaskRun, error) {
var taskRun models.TaskRun
path := joinPath("/task_runs", id.String(), "set_state")
req := struct {
State *models.StateCreate `json:"state"`
}{
State: state,
}
if err := t.client.post(ctx, path, req, &taskRun); err != nil {
return nil, fmt.Errorf("failed to set task run state: %w", err)
}
return &taskRun, nil
}
// WorkPoolsService handles operations related to work pools.
type WorkPoolsService struct {
client *Client
}
// Get retrieves a work pool by name.
func (w *WorkPoolsService) Get(ctx context.Context, name string) (*models.WorkPool, error) {
var workPool models.WorkPool
path := joinPath("/work_pools", name)
if err := w.client.get(ctx, path, &workPool); err != nil {
return nil, fmt.Errorf("failed to get work pool: %w", err)
}
return &workPool, nil
}
// WorkQueuesService handles operations related to work queues.
type WorkQueuesService struct {
client *Client
}
// Get retrieves a work queue by ID.
func (w *WorkQueuesService) Get(ctx context.Context, id uuid.UUID) (*models.WorkQueue, error) {
var workQueue models.WorkQueue
path := joinPath("/work_queues", id.String())
if err := w.client.get(ctx, path, &workQueue); err != nil {
return nil, fmt.Errorf("failed to get work queue: %w", err)
}
return &workQueue, nil
}
// VariablesService handles operations related to variables.
type VariablesService struct {
client *Client
}
// Create creates a new variable.
func (v *VariablesService) Create(ctx context.Context, req *models.VariableCreate) (*models.Variable, error) {
var variable models.Variable
if err := v.client.post(ctx, "/variables/", req, &variable); err != nil {
return nil, fmt.Errorf("failed to create variable: %w", err)
}
return &variable, nil
}
// Get retrieves a variable by ID.
func (v *VariablesService) Get(ctx context.Context, id uuid.UUID) (*models.Variable, error) {
var variable models.Variable
path := joinPath("/variables", id.String())
if err := v.client.get(ctx, path, &variable); err != nil {
return nil, fmt.Errorf("failed to get variable: %w", err)
}
return &variable, nil
}
// GetByName retrieves a variable by name.
func (v *VariablesService) GetByName(ctx context.Context, name string) (*models.Variable, error) {
var variable models.Variable
path := joinPath("/variables/name", name)
if err := v.client.get(ctx, path, &variable); err != nil {
return nil, fmt.Errorf("failed to get variable by name: %w", err)
}
return &variable, nil
}
// Update updates a variable.
func (v *VariablesService) Update(ctx context.Context, id uuid.UUID, req *models.VariableUpdate) (*models.Variable, error) {
var variable models.Variable
path := joinPath("/variables", id.String())
if err := v.client.patch(ctx, path, req, &variable); err != nil {
return nil, fmt.Errorf("failed to update variable: %w", err)
}
return &variable, nil
}
// Delete deletes a variable by ID.
func (v *VariablesService) Delete(ctx context.Context, id uuid.UUID) error {
path := joinPath("/variables", id.String())
if err := v.client.delete(ctx, path); err != nil {
return fmt.Errorf("failed to delete variable: %w", err)
}
return nil
}
// LogsService handles operations related to logs.
type LogsService struct {
client *Client
}
// Create creates new log entries.
func (l *LogsService) Create(ctx context.Context, logs []*models.LogCreate) error {
if err := l.client.post(ctx, "/logs/", logs, nil); err != nil {
return fmt.Errorf("failed to create logs: %w", err)
}
return nil
}
// List retrieves logs with filtering.
func (l *LogsService) List(ctx context.Context, filter interface{}, offset, limit int) (*pagination.PaginatedResponse[models.Log], error) {
type request struct {
Filter interface{} `json:"filter,omitempty"`
Offset int `json:"offset"`
Limit int `json:"limit"`
}
req := request{
Filter: filter,
Offset: offset,
Limit: limit,
}
type response struct {
Results []models.Log `json:"results"`
Count int `json:"count"`
}
var resp response
if err := l.client.post(ctx, "/logs/filter", req, &resp); err != nil {
return nil, fmt.Errorf("failed to list logs: %w", err)
}
return &pagination.PaginatedResponse[models.Log]{
Results: resp.Results,
Count: resp.Count,
Limit: limit,
Offset: offset,
HasMore: offset+len(resp.Results) < resp.Count,
}, nil
}
// AdminService handles administrative operations.
type AdminService struct {
client *Client
}
// Health checks the health of the Prefect server.
func (a *AdminService) Health(ctx context.Context) error {
if err := a.client.get(ctx, "/health", nil); err != nil {
return fmt.Errorf("health check failed: %w", err)
}
return nil
}
// Version retrieves the server version.
func (a *AdminService) Version(ctx context.Context) (string, error) {
var result struct {
Version string `json:"version"`
}
if err := a.client.get(ctx, "/version", &result); err != nil {
return "", fmt.Errorf("failed to get version: %w", err)
}
return result.Version, nil
}

226
pkg/errors/errors.go Normal file
View File

@@ -0,0 +1,226 @@
// Package errors provides structured error types for the Prefect API client.
package errors
import (
"encoding/json"
"fmt"
"io"
"net/http"
)
// APIError represents an error returned by the Prefect API.
type APIError struct {
// StatusCode is the HTTP status code
StatusCode int
// Message is the error message
Message string
// Details contains additional error details from the API
Details map[string]interface{}
// RequestID is the request ID if available
RequestID string
// Response is the raw HTTP response
Response *http.Response
}
// Error implements the error interface.
func (e *APIError) Error() string {
if e.Message != "" {
return fmt.Sprintf("prefect api error (status %d): %s", e.StatusCode, e.Message)
}
return fmt.Sprintf("prefect api error (status %d)", e.StatusCode)
}
// IsNotFound returns true if the error is a 404 Not Found error.
func (e *APIError) IsNotFound() bool {
return e.StatusCode == http.StatusNotFound
}
// IsUnauthorized returns true if the error is a 401 Unauthorized error.
func (e *APIError) IsUnauthorized() bool {
return e.StatusCode == http.StatusUnauthorized
}
// IsForbidden returns true if the error is a 403 Forbidden error.
func (e *APIError) IsForbidden() bool {
return e.StatusCode == http.StatusForbidden
}
// IsRateLimited returns true if the error is a 429 Too Many Requests error.
func (e *APIError) IsRateLimited() bool {
return e.StatusCode == http.StatusTooManyRequests
}
// IsServerError returns true if the error is a 5xx server error.
func (e *APIError) IsServerError() bool {
return e.StatusCode >= 500 && e.StatusCode < 600
}
// ValidationError represents a validation error from the API.
type ValidationError struct {
*APIError
ValidationErrors []ValidationDetail
}
// ValidationDetail represents a single validation error detail.
type ValidationDetail struct {
Loc []string `json:"loc"`
Msg string `json:"msg"`
Type string `json:"type"`
Ctx interface{} `json:"ctx,omitempty"`
}
// Error implements the error interface.
func (e *ValidationError) Error() string {
if len(e.ValidationErrors) == 0 {
return e.APIError.Error()
}
return fmt.Sprintf("validation error: %s", e.ValidationErrors[0].Msg)
}
// NewAPIError creates a new APIError from an HTTP response.
func NewAPIError(resp *http.Response) error {
if resp == nil {
return &APIError{
StatusCode: 0,
Message: "nil response",
}
}
apiErr := &APIError{
StatusCode: resp.StatusCode,
Response: resp,
RequestID: resp.Header.Get("X-Request-ID"),
}
// Try to read and parse the response body
if resp.Body != nil {
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err == nil && len(body) > 0 {
// Try to parse as JSON
var errorResponse struct {
Detail interface{} `json:"detail"`
}
if json.Unmarshal(body, &errorResponse) == nil {
// Check if detail is a validation error (422)
if resp.StatusCode == http.StatusUnprocessableEntity {
return parseValidationError(apiErr, errorResponse.Detail)
}
// Handle string detail
if msg, ok := errorResponse.Detail.(string); ok {
apiErr.Message = msg
} else if details, ok := errorResponse.Detail.(map[string]interface{}); ok {
apiErr.Details = details
if msg, ok := details["message"].(string); ok {
apiErr.Message = msg
}
}
} else {
// Not JSON, use raw body as message
apiErr.Message = string(body)
}
}
}
// Fallback to status text if no message
if apiErr.Message == "" {
apiErr.Message = http.StatusText(resp.StatusCode)
}
return apiErr
}
// parseValidationError parses validation errors from the API response.
func parseValidationError(apiErr *APIError, detail interface{}) error {
valErr := &ValidationError{
APIError: apiErr,
}
// Detail can be a list of validation errors
if details, ok := detail.([]interface{}); ok {
for _, d := range details {
if detailMap, ok := d.(map[string]interface{}); ok {
var vd ValidationDetail
// Parse loc
if loc, ok := detailMap["loc"].([]interface{}); ok {
for _, l := range loc {
if s, ok := l.(string); ok {
vd.Loc = append(vd.Loc, s)
}
}
}
// Parse msg
if msg, ok := detailMap["msg"].(string); ok {
vd.Msg = msg
}
// Parse type
if typ, ok := detailMap["type"].(string); ok {
vd.Type = typ
}
// Parse ctx
if ctx, ok := detailMap["ctx"]; ok {
vd.Ctx = ctx
}
valErr.ValidationErrors = append(valErr.ValidationErrors, vd)
}
}
}
return valErr
}
// IsNotFound checks if an error is a 404 Not Found error.
func IsNotFound(err error) bool {
if apiErr, ok := err.(*APIError); ok {
return apiErr.IsNotFound()
}
return false
}
// IsUnauthorized checks if an error is a 401 Unauthorized error.
func IsUnauthorized(err error) bool {
if apiErr, ok := err.(*APIError); ok {
return apiErr.IsUnauthorized()
}
return false
}
// IsForbidden checks if an error is a 403 Forbidden error.
func IsForbidden(err error) bool {
if apiErr, ok := err.(*APIError); ok {
return apiErr.IsForbidden()
}
return false
}
// IsRateLimited checks if an error is a 429 Too Many Requests error.
func IsRateLimited(err error) bool {
if apiErr, ok := err.(*APIError); ok {
return apiErr.IsRateLimited()
}
return false
}
// IsServerError checks if an error is a 5xx server error.
func IsServerError(err error) bool {
if apiErr, ok := err.(*APIError); ok {
return apiErr.IsServerError()
}
return false
}
// IsValidationError checks if an error is a validation error.
func IsValidationError(err error) bool {
_, ok := err.(*ValidationError)
return ok
}

240
pkg/errors/errors_test.go Normal file
View File

@@ -0,0 +1,240 @@
package errors
import (
"io"
"net/http"
"strings"
"testing"
)
func TestAPIError_Error(t *testing.T) {
tests := []struct {
name string
err *APIError
want string
}{
{
name: "with message",
err: &APIError{
StatusCode: 404,
Message: "Flow not found",
},
want: "prefect api error (status 404): Flow not found",
},
{
name: "without message",
err: &APIError{
StatusCode: 500,
},
want: "prefect api error (status 500)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.err.Error(); got != tt.want {
t.Errorf("Error() = %v, want %v", got, tt.want)
}
})
}
}
func TestAPIError_Checks(t *testing.T) {
tests := []struct {
name string
statusCode int
checkFuncs map[string]func(*APIError) bool
want map[string]bool
}{
{
name: "404 not found",
statusCode: 404,
checkFuncs: map[string]func(*APIError) bool{
"IsNotFound": (*APIError).IsNotFound,
"IsUnauthorized": (*APIError).IsUnauthorized,
"IsForbidden": (*APIError).IsForbidden,
"IsRateLimited": (*APIError).IsRateLimited,
"IsServerError": (*APIError).IsServerError,
},
want: map[string]bool{
"IsNotFound": true,
"IsUnauthorized": false,
"IsForbidden": false,
"IsRateLimited": false,
"IsServerError": false,
},
},
{
name: "401 unauthorized",
statusCode: 401,
checkFuncs: map[string]func(*APIError) bool{
"IsNotFound": (*APIError).IsNotFound,
"IsUnauthorized": (*APIError).IsUnauthorized,
"IsForbidden": (*APIError).IsForbidden,
"IsRateLimited": (*APIError).IsRateLimited,
"IsServerError": (*APIError).IsServerError,
},
want: map[string]bool{
"IsNotFound": false,
"IsUnauthorized": true,
"IsForbidden": false,
"IsRateLimited": false,
"IsServerError": false,
},
},
{
name: "500 server error",
statusCode: 500,
checkFuncs: map[string]func(*APIError) bool{
"IsNotFound": (*APIError).IsNotFound,
"IsUnauthorized": (*APIError).IsUnauthorized,
"IsForbidden": (*APIError).IsForbidden,
"IsRateLimited": (*APIError).IsRateLimited,
"IsServerError": (*APIError).IsServerError,
},
want: map[string]bool{
"IsNotFound": false,
"IsUnauthorized": false,
"IsForbidden": false,
"IsRateLimited": false,
"IsServerError": true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := &APIError{StatusCode: tt.statusCode}
for checkName, checkFunc := range tt.checkFuncs {
if got := checkFunc(err); got != tt.want[checkName] {
t.Errorf("%s() = %v, want %v", checkName, got, tt.want[checkName])
}
}
})
}
}
func TestNewAPIError(t *testing.T) {
tests := []struct {
name string
resp *http.Response
wantStatus int
wantMsg string
}{
{
name: "nil response",
resp: nil,
wantStatus: 0,
wantMsg: "nil response",
},
{
name: "404 with JSON body",
resp: &http.Response{
StatusCode: 404,
Body: io.NopCloser(strings.NewReader(`{"detail": "Flow not found"}`)),
Header: http.Header{},
},
wantStatus: 404,
wantMsg: "Flow not found",
},
{
name: "500 without body",
resp: &http.Response{
StatusCode: 500,
Body: io.NopCloser(strings.NewReader("")),
Header: http.Header{},
},
wantStatus: 500,
wantMsg: "Internal Server Error",
},
{
name: "with request ID",
resp: &http.Response{
StatusCode: 400,
Body: io.NopCloser(strings.NewReader(`{"detail": "Bad request"}`)),
Header: http.Header{
"X-Request-ID": []string{"req-123"},
},
},
wantStatus: 400,
wantMsg: "Bad request",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := NewAPIError(tt.resp)
apiErr, ok := err.(*APIError)
if !ok {
t.Fatalf("expected *APIError, got %T", err)
}
if apiErr.StatusCode != tt.wantStatus {
t.Errorf("StatusCode = %v, want %v", apiErr.StatusCode, tt.wantStatus)
}
if apiErr.Message != tt.wantMsg {
t.Errorf("Message = %v, want %v", apiErr.Message, tt.wantMsg)
}
})
}
}
func TestNewAPIError_Validation(t *testing.T) {
resp := &http.Response{
StatusCode: 422,
Body: io.NopCloser(strings.NewReader(`{
"detail": [
{
"loc": ["body", "name"],
"msg": "field required",
"type": "value_error.missing"
}
]
}`)),
Header: http.Header{},
}
err := NewAPIError(resp)
valErr, ok := err.(*ValidationError)
if !ok {
t.Fatalf("expected *ValidationError, got %T", err)
}
if valErr.StatusCode != 422 {
t.Errorf("StatusCode = %v, want 422", valErr.StatusCode)
}
if len(valErr.ValidationErrors) != 1 {
t.Fatalf("expected 1 validation error, got %d", len(valErr.ValidationErrors))
}
vd := valErr.ValidationErrors[0]
if len(vd.Loc) != 2 || vd.Loc[0] != "body" || vd.Loc[1] != "name" {
t.Errorf("Loc = %v, want [body name]", vd.Loc)
}
if vd.Msg != "field required" {
t.Errorf("Msg = %v, want 'field required'", vd.Msg)
}
if vd.Type != "value_error.missing" {
t.Errorf("Type = %v, want 'value_error.missing'", vd.Type)
}
}
func TestHelperFunctions(t *testing.T) {
notFoundErr := &APIError{StatusCode: 404}
unauthorizedErr := &APIError{StatusCode: 401}
serverErr := &APIError{StatusCode: 500}
valErr := &ValidationError{APIError: &APIError{StatusCode: 422}}
if !IsNotFound(notFoundErr) {
t.Error("IsNotFound should return true for 404 error")
}
if !IsUnauthorized(unauthorizedErr) {
t.Error("IsUnauthorized should return true for 401 error")
}
if !IsServerError(serverErr) {
t.Error("IsServerError should return true for 500 error")
}
if !IsValidationError(valErr) {
t.Error("IsValidationError should return true for validation error")
}
}

326
pkg/models/models.go Normal file
View File

@@ -0,0 +1,326 @@
// Package models contains the data structures for the Prefect API.
package models
import (
"encoding/json"
"time"
"github.com/google/uuid"
)
// StateType represents the type of a flow or task run state.
type StateType string
const (
StateTypePending StateType = "PENDING"
StateTypeRunning StateType = "RUNNING"
StateTypeCompleted StateType = "COMPLETED"
StateTypeFailed StateType = "FAILED"
StateTypeCancelled StateType = "CANCELLED"
StateTypeCrashed StateType = "CRASHED"
StateTypePaused StateType = "PAUSED"
StateTypeScheduled StateType = "SCHEDULED"
StateTypeCancelling StateType = "CANCELLING"
)
// DeploymentStatus represents the status of a deployment.
type DeploymentStatus string
const (
DeploymentStatusReady DeploymentStatus = "READY"
DeploymentStatusNotReady DeploymentStatus = "NOT_READY"
)
// CollisionStrategy represents the strategy for handling concurrent flow runs.
type CollisionStrategy string
const (
CollisionStrategyEnqueue CollisionStrategy = "ENQUEUE"
CollisionStrategyCancelNew CollisionStrategy = "CANCEL_NEW"
)
// Flow represents a Prefect flow.
type Flow struct {
ID uuid.UUID `json:"id"`
Created *time.Time `json:"created"`
Updated *time.Time `json:"updated"`
Name string `json:"name"`
Tags []string `json:"tags,omitempty"`
Labels map[string]interface{} `json:"labels,omitempty"`
}
// FlowCreate represents the request to create a flow.
type FlowCreate struct {
Name string `json:"name"`
Tags []string `json:"tags,omitempty"`
Labels map[string]interface{} `json:"labels,omitempty"`
}
// FlowUpdate represents the request to update a flow.
type FlowUpdate struct {
Tags *[]string `json:"tags,omitempty"`
Labels *map[string]interface{} `json:"labels,omitempty"`
}
// FlowFilter represents filter criteria for querying flows.
type FlowFilter struct {
Name *string `json:"name,omitempty"`
Tags []string `json:"tags,omitempty"`
Offset int `json:"offset,omitempty"`
Limit int `json:"limit,omitempty"`
}
// FlowRun represents a Prefect flow run.
type FlowRun struct {
ID uuid.UUID `json:"id"`
Created *time.Time `json:"created"`
Updated *time.Time `json:"updated"`
FlowID uuid.UUID `json:"flow_id"`
Name string `json:"name"`
StateID *uuid.UUID `json:"state_id"`
DeploymentID *uuid.UUID `json:"deployment_id"`
WorkQueueID *uuid.UUID `json:"work_queue_id"`
WorkQueueName *string `json:"work_queue_name"`
FlowVersion *string `json:"flow_version"`
Parameters map[string]interface{} `json:"parameters,omitempty"`
IdempotencyKey *string `json:"idempotency_key"`
Context map[string]interface{} `json:"context,omitempty"`
Tags []string `json:"tags,omitempty"`
Labels map[string]interface{} `json:"labels,omitempty"`
ParentTaskRunID *uuid.UUID `json:"parent_task_run_id"`
StateType *StateType `json:"state_type"`
StateName *string `json:"state_name"`
RunCount int `json:"run_count"`
ExpectedStartTime *time.Time `json:"expected_start_time"`
NextScheduledStartTime *time.Time `json:"next_scheduled_start_time"`
StartTime *time.Time `json:"start_time"`
EndTime *time.Time `json:"end_time"`
TotalRunTime float64 `json:"total_run_time"`
State *State `json:"state"`
}
// FlowRunCreate represents the request to create a flow run.
type FlowRunCreate struct {
FlowID uuid.UUID `json:"flow_id"`
Name string `json:"name,omitempty"`
State *StateCreate `json:"state,omitempty"`
Parameters map[string]interface{} `json:"parameters,omitempty"`
Context map[string]interface{} `json:"context,omitempty"`
Tags []string `json:"tags,omitempty"`
Labels map[string]interface{} `json:"labels,omitempty"`
IdempotencyKey *string `json:"idempotency_key,omitempty"`
WorkPoolName *string `json:"work_pool_name,omitempty"`
WorkQueueName *string `json:"work_queue_name,omitempty"`
DeploymentID *uuid.UUID `json:"deployment_id,omitempty"`
}
// FlowRunUpdate represents the request to update a flow run.
type FlowRunUpdate struct {
Name *string `json:"name,omitempty"`
}
// FlowRunFilter represents filter criteria for querying flow runs.
type FlowRunFilter struct {
FlowID *uuid.UUID `json:"flow_id,omitempty"`
DeploymentID *uuid.UUID `json:"deployment_id,omitempty"`
StateType *StateType `json:"state_type,omitempty"`
Tags []string `json:"tags,omitempty"`
Offset int `json:"offset,omitempty"`
Limit int `json:"limit,omitempty"`
}
// State represents the state of a flow or task run.
type State struct {
ID uuid.UUID `json:"id"`
Type StateType `json:"type"`
Name *string `json:"name"`
Timestamp time.Time `json:"timestamp"`
Message *string `json:"message"`
Data interface{} `json:"data,omitempty"`
StateDetails map[string]interface{} `json:"state_details,omitempty"`
}
// StateCreate represents the request to create a state.
type StateCreate struct {
Type StateType `json:"type"`
Name *string `json:"name,omitempty"`
Message *string `json:"message,omitempty"`
Data interface{} `json:"data,omitempty"`
StateDetails map[string]interface{} `json:"state_details,omitempty"`
}
// Deployment represents a Prefect deployment.
type Deployment struct {
ID uuid.UUID `json:"id"`
Created *time.Time `json:"created"`
Updated *time.Time `json:"updated"`
Name string `json:"name"`
FlowID uuid.UUID `json:"flow_id"`
Version *string `json:"version"`
Description *string `json:"description"`
Paused bool `json:"paused"`
Parameters map[string]interface{} `json:"parameters,omitempty"`
Tags []string `json:"tags,omitempty"`
Labels map[string]interface{} `json:"labels,omitempty"`
WorkQueueName *string `json:"work_queue_name"`
WorkPoolName *string `json:"work_pool_name"`
Path *string `json:"path"`
Entrypoint *string `json:"entrypoint"`
Status *DeploymentStatus `json:"status"`
EnforceParameterSchema bool `json:"enforce_parameter_schema"`
}
// DeploymentCreate represents the request to create a deployment.
type DeploymentCreate struct {
Name string `json:"name"`
FlowID uuid.UUID `json:"flow_id"`
Paused bool `json:"paused,omitempty"`
Description *string `json:"description,omitempty"`
Version *string `json:"version,omitempty"`
Parameters map[string]interface{} `json:"parameters,omitempty"`
Tags []string `json:"tags,omitempty"`
Labels map[string]interface{} `json:"labels,omitempty"`
WorkPoolName *string `json:"work_pool_name,omitempty"`
WorkQueueName *string `json:"work_queue_name,omitempty"`
Path *string `json:"path,omitempty"`
Entrypoint *string `json:"entrypoint,omitempty"`
EnforceParameterSchema bool `json:"enforce_parameter_schema,omitempty"`
}
// DeploymentUpdate represents the request to update a deployment.
type DeploymentUpdate struct {
Paused *bool `json:"paused,omitempty"`
Description *string `json:"description,omitempty"`
}
// TaskRun represents a Prefect task run.
type TaskRun struct {
ID uuid.UUID `json:"id"`
Created *time.Time `json:"created"`
Updated *time.Time `json:"updated"`
Name string `json:"name"`
FlowRunID uuid.UUID `json:"flow_run_id"`
TaskKey string `json:"task_key"`
DynamicKey string `json:"dynamic_key"`
CacheKey *string `json:"cache_key"`
StartTime *time.Time `json:"start_time"`
EndTime *time.Time `json:"end_time"`
TotalRunTime float64 `json:"total_run_time"`
Status *StateType `json:"status"`
StateID *uuid.UUID `json:"state_id"`
Tags []string `json:"tags,omitempty"`
State *State `json:"state"`
}
// TaskRunCreate represents the request to create a task run.
type TaskRunCreate struct {
Name string `json:"name"`
FlowRunID uuid.UUID `json:"flow_run_id"`
TaskKey string `json:"task_key"`
DynamicKey string `json:"dynamic_key"`
CacheKey *string `json:"cache_key,omitempty"`
State *StateCreate `json:"state,omitempty"`
Tags []string `json:"tags,omitempty"`
}
// WorkPool represents a Prefect work pool.
type WorkPool struct {
ID uuid.UUID `json:"id"`
Created *time.Time `json:"created"`
Updated *time.Time `json:"updated"`
Name string `json:"name"`
Description *string `json:"description"`
Type string `json:"type"`
IsPaused bool `json:"is_paused"`
}
// WorkQueue represents a Prefect work queue.
type WorkQueue struct {
ID uuid.UUID `json:"id"`
Created *time.Time `json:"created"`
Updated *time.Time `json:"updated"`
Name string `json:"name"`
Description *string `json:"description"`
IsPaused bool `json:"is_paused"`
Priority int `json:"priority"`
}
// Variable represents a Prefect variable.
type Variable struct {
ID uuid.UUID `json:"id"`
Created *time.Time `json:"created"`
Updated *time.Time `json:"updated"`
Name string `json:"name"`
Value string `json:"value"`
Tags []string `json:"tags,omitempty"`
}
// VariableCreate represents the request to create a variable.
type VariableCreate struct {
Name string `json:"name"`
Value string `json:"value"`
Tags []string `json:"tags,omitempty"`
}
// VariableUpdate represents the request to update a variable.
type VariableUpdate struct {
Value *string `json:"value,omitempty"`
Tags []string `json:"tags,omitempty"`
}
// Log represents a Prefect log entry.
type Log struct {
ID uuid.UUID `json:"id"`
Created *time.Time `json:"created"`
Updated *time.Time `json:"updated"`
Name string `json:"name"`
Level int `json:"level"`
Message string `json:"message"`
Timestamp time.Time `json:"timestamp"`
FlowRunID *uuid.UUID `json:"flow_run_id"`
TaskRunID *uuid.UUID `json:"task_run_id"`
}
// LogCreate represents the request to create a log.
type LogCreate struct {
Name string `json:"name"`
Level int `json:"level"`
Message string `json:"message"`
Timestamp time.Time `json:"timestamp"`
FlowRunID *uuid.UUID `json:"flow_run_id,omitempty"`
TaskRunID *uuid.UUID `json:"task_run_id,omitempty"`
}
// UnmarshalJSON implements custom JSON unmarshaling for time fields.
func (f *Flow) UnmarshalJSON(data []byte) error {
type Alias Flow
aux := &struct {
Created string `json:"created"`
Updated string `json:"updated"`
*Alias
}{
Alias: (*Alias)(f),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
if aux.Created != "" {
t, err := time.Parse(time.RFC3339, aux.Created)
if err != nil {
return err
}
f.Created = &t
}
if aux.Updated != "" {
t, err := time.Parse(time.RFC3339, aux.Updated)
if err != nil {
return err
}
f.Updated = &t
}
return nil
}

View File

@@ -0,0 +1,176 @@
// Package pagination provides pagination support for API responses.
package pagination
import (
"context"
"fmt"
)
// PaginatedResponse represents a paginated response from the API.
type PaginatedResponse[T any] struct {
// Results contains the items in this page
Results []T
// Count is the total number of items (if available)
Count int
// Limit is the maximum number of items per page
Limit int
// Offset is the offset of the first item in this page
Offset int
// HasMore indicates if there are more items to fetch
HasMore bool
}
// FetchFunc is a function that fetches a page of results.
type FetchFunc[T any] func(ctx context.Context, offset, limit int) (*PaginatedResponse[T], error)
// Iterator provides iteration over paginated results.
type Iterator[T any] struct {
fetchFunc FetchFunc[T]
limit int
offset int
current *PaginatedResponse[T]
index int
err error
done bool
}
// NewIterator creates a new pagination iterator.
func NewIterator[T any](fetchFunc FetchFunc[T], limit int) *Iterator[T] {
if limit <= 0 {
limit = 100 // Default page size
}
return &Iterator[T]{
fetchFunc: fetchFunc,
limit: limit,
offset: 0,
index: -1,
}
}
// Next advances the iterator to the next item.
// It returns true if there is an item available, false otherwise.
func (i *Iterator[T]) Next(ctx context.Context) bool {
if i.done || i.err != nil {
return false
}
// If we don't have a current page or we've reached the end of it, fetch the next page
if i.current == nil || i.index >= len(i.current.Results)-1 {
// Check if we know there are no more pages
if i.current != nil && !i.current.HasMore {
i.done = true
return false
}
// Fetch the next page
page, err := i.fetchFunc(ctx, i.offset, i.limit)
if err != nil {
i.err = err
return false
}
// Check if the page is empty
if page == nil || len(page.Results) == 0 {
i.done = true
return false
}
i.current = page
i.index = -1
i.offset += i.limit
}
// Move to the next item in the current page
i.index++
return i.index < len(i.current.Results)
}
// Value returns the current item.
// It should only be called after Next returns true.
func (i *Iterator[T]) Value() *T {
if i.current == nil || i.index < 0 || i.index >= len(i.current.Results) {
return nil
}
return &i.current.Results[i.index]
}
// Err returns any error that occurred during iteration.
func (i *Iterator[T]) Err() error {
return i.err
}
// Reset resets the iterator to the beginning.
func (i *Iterator[T]) Reset() {
i.offset = 0
i.index = -1
i.current = nil
i.err = nil
i.done = false
}
// Collect fetches all items from all pages and returns them as a slice.
// Warning: This can be memory intensive for large result sets.
func (i *Iterator[T]) Collect(ctx context.Context) ([]T, error) {
var results []T
for i.Next(ctx) {
if item := i.Value(); item != nil {
results = append(results, *item)
}
}
if err := i.Err(); err != nil {
return nil, fmt.Errorf("failed to collect all items: %w", err)
}
return results, nil
}
// CollectWithLimit fetches items up to a maximum count.
func (i *Iterator[T]) CollectWithLimit(ctx context.Context, maxItems int) ([]T, error) {
var results []T
count := 0
for i.Next(ctx) && count < maxItems {
if item := i.Value(); item != nil {
results = append(results, *item)
count++
}
}
if err := i.Err(); err != nil {
return nil, fmt.Errorf("failed to collect items: %w", err)
}
return results, nil
}
// Page represents a single page of results with metadata.
type Page[T any] struct {
Items []T
Offset int
Limit int
Total int
}
// HasNextPage returns true if there are more pages after this one.
func (p *Page[T]) HasNextPage() bool {
return p.Offset+len(p.Items) < p.Total
}
// HasPrevPage returns true if there are pages before this one.
func (p *Page[T]) HasPrevPage() bool {
return p.Offset > 0
}
// NextOffset returns the offset for the next page.
func (p *Page[T]) NextOffset() int {
return p.Offset + p.Limit
}
// PrevOffset returns the offset for the previous page.
func (p *Page[T]) PrevOffset() int {
offset := p.Offset - p.Limit
if offset < 0 {
return 0
}
return offset
}

View File

@@ -0,0 +1,339 @@
package pagination
import (
"context"
"errors"
"testing"
)
func TestNewIterator(t *testing.T) {
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
return nil, nil
}
tests := []struct {
name string
limit int
wantLimit int
}{
{"default limit", 0, 100},
{"custom limit", 50, 50},
{"negative limit", -10, 100},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
iter := NewIterator(fetchFunc, tt.limit)
if iter.limit != tt.wantLimit {
t.Errorf("limit = %d, want %d", iter.limit, tt.wantLimit)
}
})
}
}
func TestIterator_Next(t *testing.T) {
items := []string{"item1", "item2", "item3", "item4", "item5"}
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
if offset >= len(items) {
return &PaginatedResponse[string]{
Results: []string{},
Count: len(items),
Limit: limit,
Offset: offset,
HasMore: false,
}, nil
}
end := offset + limit
if end > len(items) {
end = len(items)
}
return &PaginatedResponse[string]{
Results: items[offset:end],
Count: len(items),
Limit: limit,
Offset: offset,
HasMore: end < len(items),
}, nil
}
ctx := context.Background()
iter := NewIterator(fetchFunc, 2)
var results []string
for iter.Next(ctx) {
if item := iter.Value(); item != nil {
results = append(results, *item)
}
}
if err := iter.Err(); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != len(items) {
t.Errorf("got %d items, want %d", len(results), len(items))
}
for i, want := range items {
if i >= len(results) || results[i] != want {
t.Errorf("item[%d] = %v, want %v", i, results[i], want)
}
}
}
func TestIterator_Error(t *testing.T) {
expectedErr := errors.New("fetch error")
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
return nil, expectedErr
}
ctx := context.Background()
iter := NewIterator(fetchFunc, 10)
if iter.Next(ctx) {
t.Error("Next should return false on error")
}
if err := iter.Err(); err != expectedErr {
t.Errorf("Err() = %v, want %v", err, expectedErr)
}
}
func TestIterator_EmptyResults(t *testing.T) {
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
return &PaginatedResponse[string]{
Results: []string{},
Count: 0,
Limit: limit,
Offset: offset,
HasMore: false,
}, nil
}
ctx := context.Background()
iter := NewIterator(fetchFunc, 10)
if iter.Next(ctx) {
t.Error("Next should return false for empty results")
}
}
func TestIterator_Reset(t *testing.T) {
items := []string{"item1", "item2", "item3"}
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
if offset >= len(items) {
return &PaginatedResponse[string]{
Results: []string{},
HasMore: false,
}, nil
}
return &PaginatedResponse[string]{
Results: items[offset:],
HasMore: false,
}, nil
}
ctx := context.Background()
iter := NewIterator(fetchFunc, 10)
// First iteration
count1 := 0
for iter.Next(ctx) {
count1++
}
// Reset and iterate again
iter.Reset()
count2 := 0
for iter.Next(ctx) {
count2++
}
if count1 != count2 {
t.Errorf("counts don't match after reset: %d vs %d", count1, count2)
}
if count1 != len(items) {
t.Errorf("got %d items, want %d", count1, len(items))
}
}
func TestIterator_Collect(t *testing.T) {
items := []string{"item1", "item2", "item3", "item4", "item5"}
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
if offset >= len(items) {
return &PaginatedResponse[string]{
Results: []string{},
HasMore: false,
}, nil
}
end := offset + limit
if end > len(items) {
end = len(items)
}
return &PaginatedResponse[string]{
Results: items[offset:end],
HasMore: end < len(items),
}, nil
}
ctx := context.Background()
iter := NewIterator(fetchFunc, 2)
results, err := iter.Collect(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != len(items) {
t.Errorf("got %d items, want %d", len(results), len(items))
}
}
func TestIterator_CollectWithLimit(t *testing.T) {
items := []string{"item1", "item2", "item3", "item4", "item5"}
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
if offset >= len(items) {
return &PaginatedResponse[string]{
Results: []string{},
HasMore: false,
}, nil
}
end := offset + limit
if end > len(items) {
end = len(items)
}
return &PaginatedResponse[string]{
Results: items[offset:end],
HasMore: end < len(items),
}, nil
}
ctx := context.Background()
iter := NewIterator(fetchFunc, 2)
maxItems := 3
results, err := iter.CollectWithLimit(ctx, maxItems)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != maxItems {
t.Errorf("got %d items, want %d", len(results), maxItems)
}
}
func TestPage_HasNextPage(t *testing.T) {
tests := []struct {
name string
page Page[string]
want bool
}{
{
name: "has next page",
page: Page[string]{
Items: []string{"a", "b"},
Offset: 0,
Limit: 2,
Total: 5,
},
want: true,
},
{
name: "no next page",
page: Page[string]{
Items: []string{"d", "e"},
Offset: 3,
Limit: 2,
Total: 5,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.page.HasNextPage(); got != tt.want {
t.Errorf("HasNextPage() = %v, want %v", got, tt.want)
}
})
}
}
func TestPage_HasPrevPage(t *testing.T) {
tests := []struct {
name string
page Page[string]
want bool
}{
{
name: "has previous page",
page: Page[string]{
Offset: 2,
},
want: true,
},
{
name: "no previous page",
page: Page[string]{
Offset: 0,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.page.HasPrevPage(); got != tt.want {
t.Errorf("HasPrevPage() = %v, want %v", got, tt.want)
}
})
}
}
func TestPage_NextOffset(t *testing.T) {
page := Page[string]{
Offset: 10,
Limit: 5,
}
if got := page.NextOffset(); got != 15 {
t.Errorf("NextOffset() = %v, want 15", got)
}
}
func TestPage_PrevOffset(t *testing.T) {
tests := []struct {
name string
offset int
limit int
want int
}{
{"normal case", 10, 5, 5},
{"at beginning", 3, 5, 0},
{"exact boundary", 5, 5, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
page := Page[string]{
Offset: tt.offset,
Limit: tt.limit,
}
if got := page.PrevOffset(); got != tt.want {
t.Errorf("PrevOffset() = %v, want %v", got, tt.want)
}
})
}
}

177
pkg/retry/retry.go Normal file
View File

@@ -0,0 +1,177 @@
// Package retry provides retry logic with exponential backoff for HTTP requests.
package retry
import (
"context"
"math"
"math/rand"
"net/http"
"strconv"
"time"
)
// Config defines the configuration for the retry mechanism.
type Config struct {
// MaxAttempts is the maximum number of retry attempts (including the first try).
// Default: 3
MaxAttempts int
// InitialDelay is the initial delay between retries.
// Default: 1 second
InitialDelay time.Duration
// MaxDelay is the maximum delay between retries.
// Default: 30 seconds
MaxDelay time.Duration
// BackoffFactor is the multiplier for exponential backoff.
// Default: 2.0 (exponential)
BackoffFactor float64
// RetryableErrors are HTTP status codes that should trigger a retry.
// Default: 429, 500, 502, 503, 504
RetryableErrors []int
}
// DefaultConfig returns the default retry configuration.
func DefaultConfig() Config {
return Config{
MaxAttempts: 3,
InitialDelay: time.Second,
MaxDelay: 30 * time.Second,
BackoffFactor: 2.0,
RetryableErrors: []int{429, 500, 502, 503, 504},
}
}
// Retrier handles retry logic for HTTP requests.
type Retrier struct {
config Config
}
// New creates a new Retrier with the given configuration.
func New(config Config) *Retrier {
// Apply defaults if not set
if config.MaxAttempts == 0 {
config.MaxAttempts = DefaultConfig().MaxAttempts
}
if config.InitialDelay == 0 {
config.InitialDelay = DefaultConfig().InitialDelay
}
if config.MaxDelay == 0 {
config.MaxDelay = DefaultConfig().MaxDelay
}
if config.BackoffFactor == 0 {
config.BackoffFactor = DefaultConfig().BackoffFactor
}
if len(config.RetryableErrors) == 0 {
config.RetryableErrors = DefaultConfig().RetryableErrors
}
return &Retrier{
config: config,
}
}
// Do executes the given function with retry logic.
// It returns the response and any error encountered.
func (r *Retrier) Do(ctx context.Context, fn func() (*http.Response, error)) (*http.Response, error) {
var resp *http.Response
var err error
for attempt := 1; attempt <= r.config.MaxAttempts; attempt++ {
// Check if context is cancelled
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
// Execute the function
resp, err = fn()
// If no error and response is not retryable, return success
if err == nil && !r.isRetryable(resp) {
return resp, nil
}
// If this was the last attempt, return the error
if attempt == r.config.MaxAttempts {
if err != nil {
return nil, err
}
return resp, nil
}
// Calculate delay with exponential backoff and jitter
delay := r.calculateDelay(attempt, resp)
// Wait before retrying
select {
case <-time.After(delay):
// Continue to next attempt
case <-ctx.Done():
return nil, ctx.Err()
}
}
return resp, err
}
// isRetryable checks if the HTTP response indicates a retryable error.
func (r *Retrier) isRetryable(resp *http.Response) bool {
if resp == nil {
return false
}
for _, code := range r.config.RetryableErrors {
if resp.StatusCode == code {
return true
}
}
return false
}
// calculateDelay calculates the delay before the next retry attempt.
// It implements exponential backoff with jitter and respects Retry-After header.
func (r *Retrier) calculateDelay(attempt int, resp *http.Response) time.Duration {
// Check for Retry-After header (for 429 rate limiting)
if resp != nil && resp.StatusCode == 429 {
if retryAfter := resp.Header.Get("Retry-After"); retryAfter != "" {
// Try parsing as seconds
if seconds, err := strconv.ParseInt(retryAfter, 10, 64); err == nil {
delay := time.Duration(seconds) * time.Second
if delay > r.config.MaxDelay {
delay = r.config.MaxDelay
}
return delay
}
// Try parsing as HTTP date
if t, err := http.ParseTime(retryAfter); err == nil {
delay := time.Until(t)
if delay < 0 {
delay = 0
}
if delay > r.config.MaxDelay {
delay = r.config.MaxDelay
}
return delay
}
}
}
// Calculate exponential backoff
backoff := float64(r.config.InitialDelay) * math.Pow(r.config.BackoffFactor, float64(attempt-1))
// Apply maximum delay cap
if backoff > float64(r.config.MaxDelay) {
backoff = float64(r.config.MaxDelay)
}
// Add jitter (random factor between 0.5 and 1.0)
jitter := 0.5 + rand.Float64()*0.5
delay := time.Duration(backoff * jitter)
return delay
}

219
pkg/retry/retry_test.go Normal file
View File

@@ -0,0 +1,219 @@
package retry
import (
"context"
"net/http"
"testing"
"time"
)
func TestNew(t *testing.T) {
tests := []struct {
name string
config Config
want Config
}{
{
name: "default config",
config: Config{},
want: DefaultConfig(),
},
{
name: "custom config",
config: Config{
MaxAttempts: 5,
InitialDelay: 2 * time.Second,
MaxDelay: 60 * time.Second,
BackoffFactor: 3.0,
RetryableErrors: []int{500},
},
want: Config{
MaxAttempts: 5,
InitialDelay: 2 * time.Second,
MaxDelay: 60 * time.Second,
BackoffFactor: 3.0,
RetryableErrors: []int{500},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := New(tt.config)
if r.config.MaxAttempts != tt.want.MaxAttempts {
t.Errorf("MaxAttempts = %v, want %v", r.config.MaxAttempts, tt.want.MaxAttempts)
}
if r.config.InitialDelay != tt.want.InitialDelay {
t.Errorf("InitialDelay = %v, want %v", r.config.InitialDelay, tt.want.InitialDelay)
}
})
}
}
func TestRetrier_isRetryable(t *testing.T) {
r := New(DefaultConfig())
tests := []struct {
name string
statusCode int
want bool
}{
{"429 rate limit", 429, true},
{"500 server error", 500, true},
{"502 bad gateway", 502, true},
{"503 service unavailable", 503, true},
{"504 gateway timeout", 504, true},
{"200 ok", 200, false},
{"400 bad request", 400, false},
{"404 not found", 404, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp := &http.Response{StatusCode: tt.statusCode}
if got := r.isRetryable(resp); got != tt.want {
t.Errorf("isRetryable() = %v, want %v", got, tt.want)
}
})
}
}
func TestRetrier_Do_Success(t *testing.T) {
r := New(DefaultConfig())
ctx := context.Background()
attempts := 0
fn := func() (*http.Response, error) {
attempts++
return &http.Response{StatusCode: 200}, nil
}
resp, err := r.Do(ctx, fn)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != 200 {
t.Errorf("status = %d, want 200", resp.StatusCode)
}
if attempts != 1 {
t.Errorf("attempts = %d, want 1", attempts)
}
}
func TestRetrier_Do_Retry(t *testing.T) {
config := Config{
MaxAttempts: 3,
InitialDelay: 10 * time.Millisecond,
MaxDelay: 100 * time.Millisecond,
BackoffFactor: 2.0,
RetryableErrors: []int{500},
}
r := New(config)
ctx := context.Background()
attempts := 0
fn := func() (*http.Response, error) {
attempts++
if attempts < 3 {
return &http.Response{StatusCode: 500}, nil
}
return &http.Response{StatusCode: 200}, nil
}
start := time.Now()
resp, err := r.Do(ctx, fn)
elapsed := time.Since(start)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != 200 {
t.Errorf("status = %d, want 200", resp.StatusCode)
}
if attempts != 3 {
t.Errorf("attempts = %d, want 3", attempts)
}
// Should have waited at least once
if elapsed < 10*time.Millisecond {
t.Errorf("elapsed time %v too short, expected at least 10ms", elapsed)
}
}
func TestRetrier_Do_ContextCancellation(t *testing.T) {
config := Config{
MaxAttempts: 5,
InitialDelay: 100 * time.Millisecond,
MaxDelay: 1 * time.Second,
BackoffFactor: 2.0,
RetryableErrors: []int{500},
}
r := New(config)
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
attempts := 0
fn := func() (*http.Response, error) {
attempts++
return &http.Response{StatusCode: 500}, nil
}
_, err := r.Do(ctx, fn)
if err != context.DeadlineExceeded {
t.Errorf("error = %v, want context.DeadlineExceeded", err)
}
if attempts > 2 {
t.Errorf("attempts = %d, should be cancelled early", attempts)
}
}
func TestRetrier_calculateDelay(t *testing.T) {
config := Config{
InitialDelay: time.Second,
MaxDelay: 30 * time.Second,
BackoffFactor: 2.0,
}
r := New(config)
tests := []struct {
name string
attempt int
resp *http.Response
minWant time.Duration
maxWant time.Duration
}{
{
name: "first retry",
attempt: 1,
resp: &http.Response{StatusCode: 500},
minWant: 500 * time.Millisecond, // with jitter min 0.5
maxWant: 1 * time.Second, // with jitter max 1.0
},
{
name: "second retry",
attempt: 2,
resp: &http.Response{StatusCode: 500},
minWant: 1 * time.Second, // 2^1 * 1s * 0.5
maxWant: 2 * time.Second, // 2^1 * 1s * 1.0
},
{
name: "with Retry-After header",
attempt: 1,
resp: &http.Response{
StatusCode: 429,
Header: http.Header{"Retry-After": []string{"5"}},
},
minWant: 5 * time.Second,
maxWant: 5 * time.Second,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
delay := r.calculateDelay(tt.attempt, tt.resp)
if delay < tt.minWant || delay > tt.maxWant {
t.Errorf("delay = %v, want between %v and %v", delay, tt.minWant, tt.maxWant)
}
})
}
}