Compare commits
6 Commits
v0.1.0alph
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6814c7fafb | ||
|
|
727eba0e5e | ||
|
|
79caa168db | ||
|
|
ae8a90d1e8 | ||
|
|
df4978cc04 | ||
|
|
d98b8369fd |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -37,3 +37,5 @@ pkg/models/generated*.go
|
||||
# Temporary files
|
||||
tmp/
|
||||
temp/
|
||||
|
||||
CLAUDE.md
|
||||
|
||||
6
doc.go
6
doc.go
@@ -7,7 +7,7 @@ pagination support, and comprehensive error handling.
|
||||
|
||||
# Installation
|
||||
|
||||
go get github.com/gregor/prefect-go
|
||||
go get git.schultes.dev/schultesdev/prefect-go
|
||||
|
||||
# Quick Start
|
||||
|
||||
@@ -19,8 +19,8 @@ Create a client and interact with the Prefect API:
|
||||
"context"
|
||||
"log"
|
||||
|
||||
"github.com/gregor/prefect-go/pkg/client"
|
||||
"github.com/gregor/prefect-go/pkg/models"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/client"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/models"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gregor/prefect-go/pkg/client"
|
||||
"github.com/gregor/prefect-go/pkg/models"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/client"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/models"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/gregor/prefect-go/pkg/client"
|
||||
"github.com/gregor/prefect-go/pkg/models"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/client"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/models"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gregor/prefect-go/pkg/client"
|
||||
"github.com/gregor/prefect-go/pkg/models"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/client"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/models"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -6,8 +6,8 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/gregor/prefect-go/pkg/client"
|
||||
"github.com/gregor/prefect-go/pkg/models"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/client"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/models"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
2
go.mod
2
go.mod
@@ -1,4 +1,4 @@
|
||||
module github.com/gregor/prefect-go
|
||||
module git.schultes.dev/schultesdev/prefect-go
|
||||
|
||||
go 1.21
|
||||
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
# Integration Tests
|
||||
|
||||
Integration tests require a running Prefect server instance.
|
||||
|
||||
## Running Integration Tests
|
||||
|
||||
### 1. Start a Prefect Server
|
||||
|
||||
```bash
|
||||
prefect server start
|
||||
```
|
||||
|
||||
This will start a Prefect server on `http://localhost:4200`.
|
||||
|
||||
### 2. Run the Integration Tests
|
||||
|
||||
```bash
|
||||
go test -v -tags=integration ./pkg/client/
|
||||
```
|
||||
|
||||
Or using make:
|
||||
|
||||
```bash
|
||||
make test-integration
|
||||
```
|
||||
|
||||
### 3. Custom Server URL
|
||||
|
||||
If your Prefect server is running on a different URL, set the `PREFECT_API_URL` environment variable:
|
||||
|
||||
```bash
|
||||
export PREFECT_API_URL="http://your-server:4200/api"
|
||||
go test -v -tags=integration ./pkg/client/
|
||||
```
|
||||
|
||||
## Test Coverage
|
||||
|
||||
The integration tests cover:
|
||||
|
||||
- **Flow Lifecycle**: Create, get, update, list, and delete flows
|
||||
- **Flow Run Lifecycle**: Create, get, set state, and delete flow runs
|
||||
- **Deployment Lifecycle**: Create, get, pause, resume, and delete deployments
|
||||
- **Pagination**: Manual pagination and iterator-based pagination
|
||||
- **Variables**: Create, get, update, and delete variables
|
||||
- **Admin Endpoints**: Health check and version
|
||||
- **Flow Run Monitoring**: Wait for flow run completion
|
||||
|
||||
## Notes
|
||||
|
||||
- Integration tests use the `integration` build tag to separate them from unit tests
|
||||
- Each test creates and cleans up its own resources
|
||||
- Tests use unique UUIDs in names to avoid conflicts
|
||||
- Some tests may take several seconds to complete (e.g., flow run wait tests)
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Connection Refused
|
||||
|
||||
If you see "connection refused" errors, make sure:
|
||||
1. Prefect server is running (`prefect server start`)
|
||||
2. The server is accessible at the configured URL
|
||||
3. No firewall is blocking the connection
|
||||
|
||||
### Test Timeout
|
||||
|
||||
If tests timeout:
|
||||
1. Increase the test timeout: `go test -timeout 5m -tags=integration ./pkg/client/`
|
||||
2. Check server logs for errors
|
||||
3. Ensure the server is not under heavy load
|
||||
406
pkg/client/blocks.go
Normal file
406
pkg/client/blocks.go
Normal file
@@ -0,0 +1,406 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// BlockTypesService handles operations related to block types.
|
||||
type BlockTypesService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// Create creates a new block type.
|
||||
func (s *BlockTypesService) Create(ctx context.Context, req *models.BlockTypeCreate) (*models.BlockType, error) {
|
||||
var blockType models.BlockType
|
||||
if err := s.client.post(ctx, "/block_types/", req, &blockType); err != nil {
|
||||
return nil, fmt.Errorf("failed to create block type: %w", err)
|
||||
}
|
||||
return &blockType, nil
|
||||
}
|
||||
|
||||
// Get retrieves a block type by ID.
|
||||
func (s *BlockTypesService) Get(ctx context.Context, id uuid.UUID) (*models.BlockType, error) {
|
||||
var blockType models.BlockType
|
||||
path := joinPath("/block_types", id.String())
|
||||
if err := s.client.get(ctx, path, &blockType); err != nil {
|
||||
return nil, fmt.Errorf("failed to get block type: %w", err)
|
||||
}
|
||||
return &blockType, nil
|
||||
}
|
||||
|
||||
// GetBySlug retrieves a block type by slug.
|
||||
func (s *BlockTypesService) GetBySlug(ctx context.Context, slug string) (*models.BlockType, error) {
|
||||
var blockType models.BlockType
|
||||
path := joinPath("/block_types/slug", slug)
|
||||
if err := s.client.get(ctx, path, &blockType); err != nil {
|
||||
return nil, fmt.Errorf("failed to get block type by slug: %w", err)
|
||||
}
|
||||
return &blockType, nil
|
||||
}
|
||||
|
||||
// List retrieves block types with optional filtering.
|
||||
func (s *BlockTypesService) List(ctx context.Context, filter *models.BlockTypeFilter, offset, limit int) ([]models.BlockType, error) {
|
||||
type nameFilter struct {
|
||||
Like *string `json:"like,omitempty"`
|
||||
}
|
||||
type slugFilter struct {
|
||||
Any []string `json:"any,omitempty"`
|
||||
}
|
||||
type blockTypeFilterCriteria struct {
|
||||
Name *nameFilter `json:"name,omitempty"`
|
||||
Slug *slugFilter `json:"slug,omitempty"`
|
||||
}
|
||||
type capabilitiesFilter struct {
|
||||
All []string `json:"all,omitempty"`
|
||||
}
|
||||
type blockSchemaFilterCriteria struct {
|
||||
Capabilities *capabilitiesFilter `json:"block_capabilities,omitempty"`
|
||||
}
|
||||
type request struct {
|
||||
BlockTypes *blockTypeFilterCriteria `json:"block_types,omitempty"`
|
||||
BlockSchemas *blockSchemaFilterCriteria `json:"block_schemas,omitempty"`
|
||||
Offset int `json:"offset,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
req := request{Offset: offset, Limit: limit}
|
||||
|
||||
if filter != nil {
|
||||
btf := &blockTypeFilterCriteria{}
|
||||
if filter.Name != nil {
|
||||
btf.Name = &nameFilter{Like: filter.Name}
|
||||
}
|
||||
if len(filter.Slugs) > 0 {
|
||||
btf.Slug = &slugFilter{Any: filter.Slugs}
|
||||
}
|
||||
if btf.Name != nil || btf.Slug != nil {
|
||||
req.BlockTypes = btf
|
||||
}
|
||||
if len(filter.Capabilities) > 0 {
|
||||
req.BlockSchemas = &blockSchemaFilterCriteria{
|
||||
Capabilities: &capabilitiesFilter{All: filter.Capabilities},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var blockTypes []models.BlockType
|
||||
if err := s.client.post(ctx, "/block_types/filter", req, &blockTypes); err != nil {
|
||||
return nil, fmt.Errorf("failed to list block types: %w", err)
|
||||
}
|
||||
return blockTypes, nil
|
||||
}
|
||||
|
||||
// Update updates a block type.
|
||||
func (s *BlockTypesService) Update(ctx context.Context, id uuid.UUID, req *models.BlockTypeUpdate) error {
|
||||
path := joinPath("/block_types", id.String())
|
||||
if err := s.client.patch(ctx, path, req, nil); err != nil {
|
||||
return fmt.Errorf("failed to update block type: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes a block type by ID.
|
||||
func (s *BlockTypesService) Delete(ctx context.Context, id uuid.UUID) error {
|
||||
path := joinPath("/block_types", id.String())
|
||||
if err := s.client.delete(ctx, path); err != nil {
|
||||
return fmt.Errorf("failed to delete block type: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InstallSystemBlockTypes installs the system block types.
|
||||
func (s *BlockTypesService) InstallSystemBlockTypes(ctx context.Context) error {
|
||||
if err := s.client.post(ctx, "/block_types/install_system_block_types", nil, nil); err != nil {
|
||||
return fmt.Errorf("failed to install system block types: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListBlockDocumentsBySlug retrieves all block documents for a block type identified by slug.
|
||||
func (s *BlockTypesService) ListBlockDocumentsBySlug(ctx context.Context, slug string, includeSecrets bool) ([]models.BlockDocument, error) {
|
||||
path := joinPath("/block_types/slug", slug, "block_documents")
|
||||
if includeSecrets {
|
||||
query := url.Values{}
|
||||
query.Set("include_secrets", "true")
|
||||
path = buildPathWithValues(path, query)
|
||||
}
|
||||
var blockDocuments []models.BlockDocument
|
||||
if err := s.client.get(ctx, path, &blockDocuments); err != nil {
|
||||
return nil, fmt.Errorf("failed to list block documents for block type: %w", err)
|
||||
}
|
||||
return blockDocuments, nil
|
||||
}
|
||||
|
||||
// GetBlockDocumentByName retrieves a block document by name for a given block type slug.
|
||||
func (s *BlockTypesService) GetBlockDocumentByName(ctx context.Context, slug, name string, includeSecrets bool) (*models.BlockDocument, error) {
|
||||
path := joinPath("/block_types/slug", slug, "block_documents/name", name)
|
||||
if includeSecrets {
|
||||
query := url.Values{}
|
||||
query.Set("include_secrets", "true")
|
||||
path = buildPathWithValues(path, query)
|
||||
}
|
||||
var blockDocument models.BlockDocument
|
||||
if err := s.client.get(ctx, path, &blockDocument); err != nil {
|
||||
return nil, fmt.Errorf("failed to get block document by name: %w", err)
|
||||
}
|
||||
return &blockDocument, nil
|
||||
}
|
||||
|
||||
// BlockSchemasService handles operations related to block schemas.
|
||||
type BlockSchemasService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// Create creates a new block schema.
|
||||
func (s *BlockSchemasService) Create(ctx context.Context, req *models.BlockSchemaCreate) (*models.BlockSchema, error) {
|
||||
var blockSchema models.BlockSchema
|
||||
if err := s.client.post(ctx, "/block_schemas/", req, &blockSchema); err != nil {
|
||||
return nil, fmt.Errorf("failed to create block schema: %w", err)
|
||||
}
|
||||
return &blockSchema, nil
|
||||
}
|
||||
|
||||
// Get retrieves a block schema by ID.
|
||||
func (s *BlockSchemasService) Get(ctx context.Context, id uuid.UUID) (*models.BlockSchema, error) {
|
||||
var blockSchema models.BlockSchema
|
||||
path := joinPath("/block_schemas", id.String())
|
||||
if err := s.client.get(ctx, path, &blockSchema); err != nil {
|
||||
return nil, fmt.Errorf("failed to get block schema: %w", err)
|
||||
}
|
||||
return &blockSchema, nil
|
||||
}
|
||||
|
||||
// GetByChecksum retrieves a block schema by checksum and optional version.
|
||||
func (s *BlockSchemasService) GetByChecksum(ctx context.Context, checksum string, version string) (*models.BlockSchema, error) {
|
||||
path := joinPath("/block_schemas/checksum", checksum)
|
||||
if version != "" {
|
||||
query := url.Values{}
|
||||
query.Set("version", version)
|
||||
path = buildPathWithValues(path, query)
|
||||
}
|
||||
var blockSchema models.BlockSchema
|
||||
if err := s.client.get(ctx, path, &blockSchema); err != nil {
|
||||
return nil, fmt.Errorf("failed to get block schema by checksum: %w", err)
|
||||
}
|
||||
return &blockSchema, nil
|
||||
}
|
||||
|
||||
// List retrieves block schemas with optional filtering.
|
||||
func (s *BlockSchemasService) List(ctx context.Context, filter *models.BlockSchemaFilter, offset, limit int) ([]models.BlockSchema, error) {
|
||||
type blockTypeIDFilter struct {
|
||||
Any []string `json:"any_,omitempty"`
|
||||
}
|
||||
type capabilitiesFilter struct {
|
||||
All []string `json:"all_,omitempty"`
|
||||
}
|
||||
type versionFilter struct {
|
||||
Any []string `json:"any_,omitempty"`
|
||||
}
|
||||
type blockSchemaFilterCriteria struct {
|
||||
BlockTypeID *blockTypeIDFilter `json:"block_type_id,omitempty"`
|
||||
Capabilities *capabilitiesFilter `json:"block_capabilities,omitempty"`
|
||||
Version *versionFilter `json:"version,omitempty"`
|
||||
}
|
||||
type request struct {
|
||||
BlockSchemas *blockSchemaFilterCriteria `json:"block_schemas,omitempty"`
|
||||
Offset int `json:"offset,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
req := request{Offset: offset, Limit: limit}
|
||||
|
||||
if filter != nil {
|
||||
bsf := &blockSchemaFilterCriteria{}
|
||||
if filter.BlockTypeID != nil {
|
||||
bsf.BlockTypeID = &blockTypeIDFilter{Any: []string{filter.BlockTypeID.String()}}
|
||||
}
|
||||
if len(filter.Capabilities) > 0 {
|
||||
bsf.Capabilities = &capabilitiesFilter{All: filter.Capabilities}
|
||||
}
|
||||
if filter.Version != nil {
|
||||
bsf.Version = &versionFilter{Any: []string{*filter.Version}}
|
||||
}
|
||||
if bsf.BlockTypeID != nil || bsf.Capabilities != nil || bsf.Version != nil {
|
||||
req.BlockSchemas = bsf
|
||||
}
|
||||
}
|
||||
|
||||
var blockSchemas []models.BlockSchema
|
||||
if err := s.client.post(ctx, "/block_schemas/filter", req, &blockSchemas); err != nil {
|
||||
return nil, fmt.Errorf("failed to list block schemas: %w", err)
|
||||
}
|
||||
return blockSchemas, nil
|
||||
}
|
||||
|
||||
// Delete deletes a block schema by ID.
|
||||
func (s *BlockSchemasService) Delete(ctx context.Context, id uuid.UUID) error {
|
||||
path := joinPath("/block_schemas", id.String())
|
||||
if err := s.client.delete(ctx, path); err != nil {
|
||||
return fmt.Errorf("failed to delete block schema: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BlockDocumentsService handles operations related to block documents.
|
||||
type BlockDocumentsService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// Create creates a new block document.
|
||||
func (s *BlockDocumentsService) Create(ctx context.Context, req *models.BlockDocumentCreate) (*models.BlockDocument, error) {
|
||||
var blockDocument models.BlockDocument
|
||||
if err := s.client.post(ctx, "/block_documents/", req, &blockDocument); err != nil {
|
||||
return nil, fmt.Errorf("failed to create block document: %w", err)
|
||||
}
|
||||
return &blockDocument, nil
|
||||
}
|
||||
|
||||
// Get retrieves a block document by ID.
|
||||
func (s *BlockDocumentsService) Get(ctx context.Context, id uuid.UUID, includeSecrets bool) (*models.BlockDocument, error) {
|
||||
path := joinPath("/block_documents", id.String())
|
||||
if includeSecrets {
|
||||
query := url.Values{}
|
||||
query.Set("include_secrets", "true")
|
||||
path = buildPathWithValues(path, query)
|
||||
}
|
||||
var blockDocument models.BlockDocument
|
||||
if err := s.client.get(ctx, path, &blockDocument); err != nil {
|
||||
return nil, fmt.Errorf("failed to get block document: %w", err)
|
||||
}
|
||||
return &blockDocument, nil
|
||||
}
|
||||
|
||||
// List retrieves block documents with optional filtering.
|
||||
func (s *BlockDocumentsService) List(ctx context.Context, filter *models.BlockDocumentFilter, includeSecrets bool, sort models.BlockDocumentSort, offset, limit int) ([]models.BlockDocument, error) {
|
||||
type idFilter struct {
|
||||
Any []string `json:"any_,omitempty"`
|
||||
}
|
||||
type nameFilter struct {
|
||||
Any []string `json:"any_,omitempty"`
|
||||
}
|
||||
type boolFilter struct {
|
||||
Eq bool `json:"eq_"`
|
||||
}
|
||||
type blockDocumentFilterCriteria struct {
|
||||
BlockTypeID *idFilter `json:"block_type_id,omitempty"`
|
||||
Name *nameFilter `json:"name,omitempty"`
|
||||
IsAnonymous *boolFilter `json:"is_anonymous,omitempty"`
|
||||
}
|
||||
type request struct {
|
||||
BlockDocuments *blockDocumentFilterCriteria `json:"block_documents,omitempty"`
|
||||
IncludeSecrets bool `json:"include_secrets,omitempty"`
|
||||
Sort models.BlockDocumentSort `json:"sort,omitempty"`
|
||||
Offset int `json:"offset,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
req := request{
|
||||
IncludeSecrets: includeSecrets,
|
||||
Sort: sort,
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
}
|
||||
|
||||
if filter != nil {
|
||||
bdf := &blockDocumentFilterCriteria{}
|
||||
if filter.BlockTypeID != nil {
|
||||
bdf.BlockTypeID = &idFilter{Any: []string{filter.BlockTypeID.String()}}
|
||||
}
|
||||
if filter.Name != nil {
|
||||
bdf.Name = &nameFilter{Any: []string{*filter.Name}}
|
||||
}
|
||||
if filter.IsAnonymous != nil {
|
||||
bdf.IsAnonymous = &boolFilter{Eq: *filter.IsAnonymous}
|
||||
}
|
||||
if bdf.BlockTypeID != nil || bdf.Name != nil || bdf.IsAnonymous != nil {
|
||||
req.BlockDocuments = bdf
|
||||
}
|
||||
}
|
||||
|
||||
var blockDocuments []models.BlockDocument
|
||||
if err := s.client.post(ctx, "/block_documents/filter", req, &blockDocuments); err != nil {
|
||||
return nil, fmt.Errorf("failed to list block documents: %w", err)
|
||||
}
|
||||
return blockDocuments, nil
|
||||
}
|
||||
|
||||
// Count returns the number of block documents matching the filter.
|
||||
func (s *BlockDocumentsService) Count(ctx context.Context, filter *models.BlockDocumentFilter) (int, error) {
|
||||
type idFilter struct {
|
||||
Any []string `json:"any_,omitempty"`
|
||||
}
|
||||
type nameFilter struct {
|
||||
Any []string `json:"any_,omitempty"`
|
||||
}
|
||||
type boolFilter struct {
|
||||
Eq bool `json:"eq_"`
|
||||
}
|
||||
type blockDocumentFilterCriteria struct {
|
||||
BlockTypeID *idFilter `json:"block_type_id,omitempty"`
|
||||
Name *nameFilter `json:"name,omitempty"`
|
||||
IsAnonymous *boolFilter `json:"is_anonymous,omitempty"`
|
||||
}
|
||||
type request struct {
|
||||
BlockDocuments *blockDocumentFilterCriteria `json:"block_documents,omitempty"`
|
||||
}
|
||||
|
||||
var req request
|
||||
if filter != nil {
|
||||
bdf := &blockDocumentFilterCriteria{}
|
||||
if filter.BlockTypeID != nil {
|
||||
bdf.BlockTypeID = &idFilter{Any: []string{filter.BlockTypeID.String()}}
|
||||
}
|
||||
if filter.Name != nil {
|
||||
bdf.Name = &nameFilter{Any: []string{*filter.Name}}
|
||||
}
|
||||
if filter.IsAnonymous != nil {
|
||||
bdf.IsAnonymous = &boolFilter{Eq: *filter.IsAnonymous}
|
||||
}
|
||||
if bdf.BlockTypeID != nil || bdf.Name != nil || bdf.IsAnonymous != nil {
|
||||
req.BlockDocuments = bdf
|
||||
}
|
||||
}
|
||||
|
||||
var count int
|
||||
if err := s.client.post(ctx, "/block_documents/count", req, &count); err != nil {
|
||||
return 0, fmt.Errorf("failed to count block documents: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Update updates a block document's data.
|
||||
func (s *BlockDocumentsService) Update(ctx context.Context, id uuid.UUID, req *models.BlockDocumentUpdate) error {
|
||||
path := joinPath("/block_documents", id.String())
|
||||
if err := s.client.patch(ctx, path, req, nil); err != nil {
|
||||
return fmt.Errorf("failed to update block document: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete deletes a block document by ID.
|
||||
func (s *BlockDocumentsService) Delete(ctx context.Context, id uuid.UUID) error {
|
||||
path := joinPath("/block_documents", id.String())
|
||||
if err := s.client.delete(ctx, path); err != nil {
|
||||
return fmt.Errorf("failed to delete block document: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BlockCapabilitiesService handles operations related to block capabilities.
|
||||
type BlockCapabilitiesService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// List retrieves all available block capabilities.
|
||||
func (s *BlockCapabilitiesService) List(ctx context.Context) ([]string, error) {
|
||||
var capabilities []string
|
||||
if err := s.client.get(ctx, "/block_capabilities/", &capabilities); err != nil {
|
||||
return nil, fmt.Errorf("failed to list block capabilities: %w", err)
|
||||
}
|
||||
return capabilities, nil
|
||||
}
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gregor/prefect-go/pkg/errors"
|
||||
"github.com/gregor/prefect-go/pkg/retry"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/errors"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/retry"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -31,15 +31,19 @@ type Client struct {
|
||||
userAgent string
|
||||
|
||||
// Services
|
||||
Flows *FlowsService
|
||||
FlowRuns *FlowRunsService
|
||||
Deployments *DeploymentsService
|
||||
TaskRuns *TaskRunsService
|
||||
WorkPools *WorkPoolsService
|
||||
WorkQueues *WorkQueuesService
|
||||
Variables *VariablesService
|
||||
Logs *LogsService
|
||||
Admin *AdminService
|
||||
Flows *FlowsService
|
||||
FlowRuns *FlowRunsService
|
||||
Deployments *DeploymentsService
|
||||
TaskRuns *TaskRunsService
|
||||
WorkPools *WorkPoolsService
|
||||
WorkQueues *WorkQueuesService
|
||||
Variables *VariablesService
|
||||
Logs *LogsService
|
||||
Admin *AdminService
|
||||
BlockTypes *BlockTypesService
|
||||
BlockSchemas *BlockSchemasService
|
||||
BlockDocuments *BlockDocumentsService
|
||||
BlockCapabilities *BlockCapabilitiesService
|
||||
}
|
||||
|
||||
// Option is a functional option for configuring the client.
|
||||
@@ -79,6 +83,10 @@ func NewClient(opts ...Option) (*Client, error) {
|
||||
c.Variables = &VariablesService{client: c}
|
||||
c.Logs = &LogsService{client: c}
|
||||
c.Admin = &AdminService{client: c}
|
||||
c.BlockTypes = &BlockTypesService{client: c}
|
||||
c.BlockSchemas = &BlockSchemasService{client: c}
|
||||
c.BlockDocuments = &BlockDocumentsService{client: c}
|
||||
c.BlockCapabilities = &BlockCapabilitiesService{client: c}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
@@ -140,11 +148,16 @@ func WithUserAgent(ua string) Option {
|
||||
|
||||
// do executes an HTTP request with retry logic.
|
||||
func (c *Client) do(ctx context.Context, method, path string, body, result interface{}) error {
|
||||
// Build full URL
|
||||
u, err := c.baseURL.Parse(path)
|
||||
// Build full URL by joining the base URL path with the service path.
|
||||
// url.URL.Parse() is not used here because it would replace the base path
|
||||
// when the service path is absolute (starts with "/").
|
||||
refURL, err := url.Parse(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse path: %w", err)
|
||||
}
|
||||
u := *c.baseURL
|
||||
u.Path = strings.TrimRight(c.baseURL.Path, "/") + "/" + strings.TrimLeft(refURL.Path, "/")
|
||||
u.RawQuery = refURL.RawQuery
|
||||
|
||||
// Serialize request body if present
|
||||
var reqBody io.Reader
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/models"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/pagination"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gregor/prefect-go/pkg/models"
|
||||
"github.com/gregor/prefect-go/pkg/pagination"
|
||||
)
|
||||
|
||||
// FlowRunsService handles operations related to flow runs.
|
||||
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/models"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/pagination"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gregor/prefect-go/pkg/models"
|
||||
"github.com/gregor/prefect-go/pkg/pagination"
|
||||
)
|
||||
|
||||
// FlowsService handles operations related to flows.
|
||||
|
||||
@@ -1,415 +0,0 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gregor/prefect-go/pkg/client"
|
||||
"github.com/gregor/prefect-go/pkg/errors"
|
||||
"github.com/gregor/prefect-go/pkg/models"
|
||||
)
|
||||
|
||||
var testClient *client.Client
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Get base URL from environment or use default
|
||||
baseURL := os.Getenv("PREFECT_API_URL")
|
||||
if baseURL == "" {
|
||||
baseURL = "http://localhost:4200/api"
|
||||
}
|
||||
|
||||
// Create test client
|
||||
var err error
|
||||
testClient, err = client.NewClient(
|
||||
client.WithBaseURL(baseURL),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestIntegration_FlowLifecycle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a flow
|
||||
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
|
||||
Name: "test-flow-" + uuid.New().String(),
|
||||
Tags: []string{"integration-test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow: %v", err)
|
||||
}
|
||||
defer testClient.Flows.Delete(ctx, flow.ID)
|
||||
|
||||
// Verify flow was created
|
||||
if flow.ID == uuid.Nil {
|
||||
t.Error("Flow ID is nil")
|
||||
}
|
||||
if flow.Name == "" {
|
||||
t.Error("Flow name is empty")
|
||||
}
|
||||
|
||||
// Get the flow
|
||||
retrievedFlow, err := testClient.Flows.Get(ctx, flow.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get flow: %v", err)
|
||||
}
|
||||
if retrievedFlow.ID != flow.ID {
|
||||
t.Errorf("Flow ID mismatch: got %v, want %v", retrievedFlow.ID, flow.ID)
|
||||
}
|
||||
|
||||
// Update the flow
|
||||
newTags := []string{"integration-test", "updated"}
|
||||
updatedFlow, err := testClient.Flows.Update(ctx, flow.ID, &models.FlowUpdate{
|
||||
Tags: &newTags,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update flow: %v", err)
|
||||
}
|
||||
if len(updatedFlow.Tags) != 2 {
|
||||
t.Errorf("Expected 2 tags, got %d", len(updatedFlow.Tags))
|
||||
}
|
||||
|
||||
// List flows
|
||||
flowsPage, err := testClient.Flows.List(ctx, &models.FlowFilter{
|
||||
Tags: []string{"integration-test"},
|
||||
}, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list flows: %v", err)
|
||||
}
|
||||
if len(flowsPage.Results) == 0 {
|
||||
t.Error("Expected at least one flow in results")
|
||||
}
|
||||
|
||||
// Delete the flow
|
||||
if err := testClient.Flows.Delete(ctx, flow.ID); err != nil {
|
||||
t.Fatalf("Failed to delete flow: %v", err)
|
||||
}
|
||||
|
||||
// Verify deletion
|
||||
_, err = testClient.Flows.Get(ctx, flow.ID)
|
||||
if !errors.IsNotFound(err) {
|
||||
t.Errorf("Expected NotFound error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_FlowRunLifecycle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a flow first
|
||||
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
|
||||
Name: "test-flow-run-" + uuid.New().String(),
|
||||
Tags: []string{"integration-test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow: %v", err)
|
||||
}
|
||||
defer testClient.Flows.Delete(ctx, flow.ID)
|
||||
|
||||
// Create a flow run
|
||||
flowRun, err := testClient.FlowRuns.Create(ctx, &models.FlowRunCreate{
|
||||
FlowID: flow.ID,
|
||||
Name: "test-run",
|
||||
Parameters: map[string]interface{}{
|
||||
"test_param": "value",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow run: %v", err)
|
||||
}
|
||||
defer testClient.FlowRuns.Delete(ctx, flowRun.ID)
|
||||
|
||||
// Verify flow run was created
|
||||
if flowRun.ID == uuid.Nil {
|
||||
t.Error("Flow run ID is nil")
|
||||
}
|
||||
if flowRun.FlowID != flow.ID {
|
||||
t.Errorf("Flow ID mismatch: got %v, want %v", flowRun.FlowID, flow.ID)
|
||||
}
|
||||
|
||||
// Get the flow run
|
||||
retrievedRun, err := testClient.FlowRuns.Get(ctx, flowRun.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get flow run: %v", err)
|
||||
}
|
||||
if retrievedRun.ID != flowRun.ID {
|
||||
t.Errorf("Flow run ID mismatch: got %v, want %v", retrievedRun.ID, flowRun.ID)
|
||||
}
|
||||
|
||||
// Set state to RUNNING
|
||||
runningState := models.StateTypeRunning
|
||||
updatedRun, err := testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{
|
||||
Type: runningState,
|
||||
Message: strPtr("Test running"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set state: %v", err)
|
||||
}
|
||||
if updatedRun.StateType == nil || *updatedRun.StateType != runningState {
|
||||
t.Errorf("State type mismatch: got %v, want %v", updatedRun.StateType, runningState)
|
||||
}
|
||||
|
||||
// Set state to COMPLETED
|
||||
completedState := models.StateTypeCompleted
|
||||
completedRun, err := testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{
|
||||
Type: completedState,
|
||||
Message: strPtr("Test completed"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set state: %v", err)
|
||||
}
|
||||
if completedRun.StateType == nil || *completedRun.StateType != completedState {
|
||||
t.Errorf("State type mismatch: got %v, want %v", completedRun.StateType, completedState)
|
||||
}
|
||||
|
||||
// List flow runs
|
||||
runsPage, err := testClient.FlowRuns.List(ctx, &models.FlowRunFilter{
|
||||
FlowID: &flow.ID,
|
||||
}, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list flow runs: %v", err)
|
||||
}
|
||||
if len(runsPage.Results) == 0 {
|
||||
t.Error("Expected at least one flow run in results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_DeploymentLifecycle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a flow first
|
||||
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
|
||||
Name: "test-deployment-flow-" + uuid.New().String(),
|
||||
Tags: []string{"integration-test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow: %v", err)
|
||||
}
|
||||
defer testClient.Flows.Delete(ctx, flow.ID)
|
||||
|
||||
// Create a deployment
|
||||
workPoolName := "default-pool"
|
||||
deployment, err := testClient.Deployments.Create(ctx, &models.DeploymentCreate{
|
||||
Name: "test-deployment",
|
||||
FlowID: flow.ID,
|
||||
WorkPoolName: &workPoolName,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create deployment: %v", err)
|
||||
}
|
||||
defer testClient.Deployments.Delete(ctx, deployment.ID)
|
||||
|
||||
// Verify deployment was created
|
||||
if deployment.ID == uuid.Nil {
|
||||
t.Error("Deployment ID is nil")
|
||||
}
|
||||
if deployment.FlowID != flow.ID {
|
||||
t.Errorf("Flow ID mismatch: got %v, want %v", deployment.FlowID, flow.ID)
|
||||
}
|
||||
|
||||
// Get the deployment
|
||||
retrievedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get deployment: %v", err)
|
||||
}
|
||||
if retrievedDeployment.ID != deployment.ID {
|
||||
t.Errorf("Deployment ID mismatch: got %v, want %v", retrievedDeployment.ID, deployment.ID)
|
||||
}
|
||||
|
||||
// Pause the deployment
|
||||
if err := testClient.Deployments.Pause(ctx, deployment.ID); err != nil {
|
||||
t.Fatalf("Failed to pause deployment: %v", err)
|
||||
}
|
||||
|
||||
// Verify paused
|
||||
pausedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get deployment: %v", err)
|
||||
}
|
||||
if !pausedDeployment.Paused {
|
||||
t.Error("Expected deployment to be paused")
|
||||
}
|
||||
|
||||
// Resume the deployment
|
||||
if err := testClient.Deployments.Resume(ctx, deployment.ID); err != nil {
|
||||
t.Fatalf("Failed to resume deployment: %v", err)
|
||||
}
|
||||
|
||||
// Verify resumed
|
||||
resumedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get deployment: %v", err)
|
||||
}
|
||||
if resumedDeployment.Paused {
|
||||
t.Error("Expected deployment to be resumed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_Pagination(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create multiple flows
|
||||
createdFlows := make([]uuid.UUID, 0)
|
||||
for i := 0; i < 15; i++ {
|
||||
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
|
||||
Name: "pagination-test-" + uuid.New().String(),
|
||||
Tags: []string{"pagination-integration-test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow %d: %v", i, err)
|
||||
}
|
||||
createdFlows = append(createdFlows, flow.ID)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
defer func() {
|
||||
for _, id := range createdFlows {
|
||||
testClient.Flows.Delete(ctx, id)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test manual pagination
|
||||
page1, err := testClient.Flows.List(ctx, &models.FlowFilter{
|
||||
Tags: []string{"pagination-integration-test"},
|
||||
}, 0, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list flows: %v", err)
|
||||
}
|
||||
if len(page1.Results) != 5 {
|
||||
t.Errorf("Expected 5 flows in page 1, got %d", len(page1.Results))
|
||||
}
|
||||
if !page1.HasMore {
|
||||
t.Error("Expected more pages")
|
||||
}
|
||||
|
||||
// Test iterator
|
||||
iter := testClient.Flows.ListAll(ctx, &models.FlowFilter{
|
||||
Tags: []string{"pagination-integration-test"},
|
||||
})
|
||||
|
||||
count := 0
|
||||
for iter.Next(ctx) {
|
||||
count++
|
||||
}
|
||||
if err := iter.Err(); err != nil {
|
||||
t.Fatalf("Iterator error: %v", err)
|
||||
}
|
||||
if count != 15 {
|
||||
t.Errorf("Expected to iterate over 15 flows, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_Variables(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a variable
|
||||
variable, err := testClient.Variables.Create(ctx, &models.VariableCreate{
|
||||
Name: "test-var-" + uuid.New().String(),
|
||||
Value: "test-value",
|
||||
Tags: []string{"integration-test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create variable: %v", err)
|
||||
}
|
||||
defer testClient.Variables.Delete(ctx, variable.ID)
|
||||
|
||||
// Get the variable
|
||||
retrievedVar, err := testClient.Variables.Get(ctx, variable.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get variable: %v", err)
|
||||
}
|
||||
if retrievedVar.Value != "test-value" {
|
||||
t.Errorf("Value mismatch: got %v, want test-value", retrievedVar.Value)
|
||||
}
|
||||
|
||||
// Update the variable
|
||||
newValue := "updated-value"
|
||||
updatedVar, err := testClient.Variables.Update(ctx, variable.ID, &models.VariableUpdate{
|
||||
Value: &newValue,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update variable: %v", err)
|
||||
}
|
||||
if updatedVar.Value != newValue {
|
||||
t.Errorf("Value mismatch: got %v, want %v", updatedVar.Value, newValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_AdminEndpoints(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test health check
|
||||
if err := testClient.Admin.Health(ctx); err != nil {
|
||||
t.Fatalf("Health check failed: %v", err)
|
||||
}
|
||||
|
||||
// Test version
|
||||
version, err := testClient.Admin.Version(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get version: %v", err)
|
||||
}
|
||||
if version == "" {
|
||||
t.Error("Version is empty")
|
||||
}
|
||||
t.Logf("Server version: %s", version)
|
||||
}
|
||||
|
||||
func TestIntegration_FlowRunWait(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a flow
|
||||
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
|
||||
Name: "wait-test-" + uuid.New().String(),
|
||||
Tags: []string{"integration-test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow: %v", err)
|
||||
}
|
||||
defer testClient.Flows.Delete(ctx, flow.ID)
|
||||
|
||||
// Create a flow run
|
||||
flowRun, err := testClient.FlowRuns.Create(ctx, &models.FlowRunCreate{
|
||||
FlowID: flow.ID,
|
||||
Name: "wait-test-run",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow run: %v", err)
|
||||
}
|
||||
defer testClient.FlowRuns.Delete(ctx, flowRun.ID)
|
||||
|
||||
// Simulate completion in background
|
||||
go func() {
|
||||
time.Sleep(2 * time.Second)
|
||||
completedState := models.StateTypeCompleted
|
||||
testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{
|
||||
Type: completedState,
|
||||
Message: strPtr("Test completed"),
|
||||
})
|
||||
}()
|
||||
|
||||
// Wait for completion with timeout
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
finalRun, err := testClient.FlowRuns.Wait(waitCtx, flowRun.ID, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to wait for flow run: %v", err)
|
||||
}
|
||||
|
||||
if finalRun.StateType == nil || *finalRun.StateType != models.StateTypeCompleted {
|
||||
t.Errorf("Expected COMPLETED state, got %v", finalRun.StateType)
|
||||
}
|
||||
}
|
||||
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/models"
|
||||
"git.schultes.dev/schultesdev/prefect-go/pkg/pagination"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gregor/prefect-go/pkg/models"
|
||||
"github.com/gregor/prefect-go/pkg/pagination"
|
||||
)
|
||||
|
||||
// DeploymentsService handles operations related to deployments.
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAPIError_Error(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err *APIError
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "with message",
|
||||
err: &APIError{
|
||||
StatusCode: 404,
|
||||
Message: "Flow not found",
|
||||
},
|
||||
want: "prefect api error (status 404): Flow not found",
|
||||
},
|
||||
{
|
||||
name: "without message",
|
||||
err: &APIError{
|
||||
StatusCode: 500,
|
||||
},
|
||||
want: "prefect api error (status 500)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.err.Error(); got != tt.want {
|
||||
t.Errorf("Error() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_Checks(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
checkFuncs map[string]func(*APIError) bool
|
||||
want map[string]bool
|
||||
}{
|
||||
{
|
||||
name: "404 not found",
|
||||
statusCode: 404,
|
||||
checkFuncs: map[string]func(*APIError) bool{
|
||||
"IsNotFound": (*APIError).IsNotFound,
|
||||
"IsUnauthorized": (*APIError).IsUnauthorized,
|
||||
"IsForbidden": (*APIError).IsForbidden,
|
||||
"IsRateLimited": (*APIError).IsRateLimited,
|
||||
"IsServerError": (*APIError).IsServerError,
|
||||
},
|
||||
want: map[string]bool{
|
||||
"IsNotFound": true,
|
||||
"IsUnauthorized": false,
|
||||
"IsForbidden": false,
|
||||
"IsRateLimited": false,
|
||||
"IsServerError": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "401 unauthorized",
|
||||
statusCode: 401,
|
||||
checkFuncs: map[string]func(*APIError) bool{
|
||||
"IsNotFound": (*APIError).IsNotFound,
|
||||
"IsUnauthorized": (*APIError).IsUnauthorized,
|
||||
"IsForbidden": (*APIError).IsForbidden,
|
||||
"IsRateLimited": (*APIError).IsRateLimited,
|
||||
"IsServerError": (*APIError).IsServerError,
|
||||
},
|
||||
want: map[string]bool{
|
||||
"IsNotFound": false,
|
||||
"IsUnauthorized": true,
|
||||
"IsForbidden": false,
|
||||
"IsRateLimited": false,
|
||||
"IsServerError": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "500 server error",
|
||||
statusCode: 500,
|
||||
checkFuncs: map[string]func(*APIError) bool{
|
||||
"IsNotFound": (*APIError).IsNotFound,
|
||||
"IsUnauthorized": (*APIError).IsUnauthorized,
|
||||
"IsForbidden": (*APIError).IsForbidden,
|
||||
"IsRateLimited": (*APIError).IsRateLimited,
|
||||
"IsServerError": (*APIError).IsServerError,
|
||||
},
|
||||
want: map[string]bool{
|
||||
"IsNotFound": false,
|
||||
"IsUnauthorized": false,
|
||||
"IsForbidden": false,
|
||||
"IsRateLimited": false,
|
||||
"IsServerError": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := &APIError{StatusCode: tt.statusCode}
|
||||
for checkName, checkFunc := range tt.checkFuncs {
|
||||
if got := checkFunc(err); got != tt.want[checkName] {
|
||||
t.Errorf("%s() = %v, want %v", checkName, got, tt.want[checkName])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAPIError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
resp *http.Response
|
||||
wantStatus int
|
||||
wantMsg string
|
||||
}{
|
||||
{
|
||||
name: "nil response",
|
||||
resp: nil,
|
||||
wantStatus: 0,
|
||||
wantMsg: "nil response",
|
||||
},
|
||||
{
|
||||
name: "404 with JSON body",
|
||||
resp: &http.Response{
|
||||
StatusCode: 404,
|
||||
Body: io.NopCloser(strings.NewReader(`{"detail": "Flow not found"}`)),
|
||||
Header: http.Header{},
|
||||
},
|
||||
wantStatus: 404,
|
||||
wantMsg: "Flow not found",
|
||||
},
|
||||
{
|
||||
name: "500 without body",
|
||||
resp: &http.Response{
|
||||
StatusCode: 500,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
Header: http.Header{},
|
||||
},
|
||||
wantStatus: 500,
|
||||
wantMsg: "Internal Server Error",
|
||||
},
|
||||
{
|
||||
name: "with request ID",
|
||||
resp: &http.Response{
|
||||
StatusCode: 400,
|
||||
Body: io.NopCloser(strings.NewReader(`{"detail": "Bad request"}`)),
|
||||
Header: http.Header{
|
||||
"X-Request-ID": []string{"req-123"},
|
||||
},
|
||||
},
|
||||
wantStatus: 400,
|
||||
wantMsg: "Bad request",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := NewAPIError(tt.resp)
|
||||
apiErr, ok := err.(*APIError)
|
||||
if !ok {
|
||||
t.Fatalf("expected *APIError, got %T", err)
|
||||
}
|
||||
if apiErr.StatusCode != tt.wantStatus {
|
||||
t.Errorf("StatusCode = %v, want %v", apiErr.StatusCode, tt.wantStatus)
|
||||
}
|
||||
if apiErr.Message != tt.wantMsg {
|
||||
t.Errorf("Message = %v, want %v", apiErr.Message, tt.wantMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAPIError_Validation(t *testing.T) {
|
||||
resp := &http.Response{
|
||||
StatusCode: 422,
|
||||
Body: io.NopCloser(strings.NewReader(`{
|
||||
"detail": [
|
||||
{
|
||||
"loc": ["body", "name"],
|
||||
"msg": "field required",
|
||||
"type": "value_error.missing"
|
||||
}
|
||||
]
|
||||
}`)),
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
err := NewAPIError(resp)
|
||||
valErr, ok := err.(*ValidationError)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ValidationError, got %T", err)
|
||||
}
|
||||
|
||||
if valErr.StatusCode != 422 {
|
||||
t.Errorf("StatusCode = %v, want 422", valErr.StatusCode)
|
||||
}
|
||||
|
||||
if len(valErr.ValidationErrors) != 1 {
|
||||
t.Fatalf("expected 1 validation error, got %d", len(valErr.ValidationErrors))
|
||||
}
|
||||
|
||||
vd := valErr.ValidationErrors[0]
|
||||
if len(vd.Loc) != 2 || vd.Loc[0] != "body" || vd.Loc[1] != "name" {
|
||||
t.Errorf("Loc = %v, want [body name]", vd.Loc)
|
||||
}
|
||||
if vd.Msg != "field required" {
|
||||
t.Errorf("Msg = %v, want 'field required'", vd.Msg)
|
||||
}
|
||||
if vd.Type != "value_error.missing" {
|
||||
t.Errorf("Type = %v, want 'value_error.missing'", vd.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelperFunctions(t *testing.T) {
|
||||
notFoundErr := &APIError{StatusCode: 404}
|
||||
unauthorizedErr := &APIError{StatusCode: 401}
|
||||
serverErr := &APIError{StatusCode: 500}
|
||||
valErr := &ValidationError{APIError: &APIError{StatusCode: 422}}
|
||||
|
||||
if !IsNotFound(notFoundErr) {
|
||||
t.Error("IsNotFound should return true for 404 error")
|
||||
}
|
||||
if !IsUnauthorized(unauthorizedErr) {
|
||||
t.Error("IsUnauthorized should return true for 401 error")
|
||||
}
|
||||
if !IsServerError(serverErr) {
|
||||
t.Error("IsServerError should return true for 500 error")
|
||||
}
|
||||
if !IsValidationError(valErr) {
|
||||
t.Error("IsValidationError should return true for validation error")
|
||||
}
|
||||
}
|
||||
@@ -291,36 +291,365 @@ type LogCreate struct {
|
||||
TaskRunID *uuid.UUID `json:"task_run_id,omitempty"`
|
||||
}
|
||||
|
||||
// BlockDocumentSort represents sort options for block documents.
|
||||
type BlockDocumentSort string
|
||||
|
||||
const (
|
||||
BlockDocumentSortNameAsc BlockDocumentSort = "NAME_ASC"
|
||||
BlockDocumentSortNameDesc BlockDocumentSort = "NAME_DESC"
|
||||
)
|
||||
|
||||
// BlockType represents a Prefect block type.
|
||||
type BlockType struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Created *time.Time `json:"created"`
|
||||
Updated *time.Time `json:"updated"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
LogoURL *string `json:"logo_url"`
|
||||
DocumentationURL *string `json:"documentation_url"`
|
||||
Description *string `json:"description"`
|
||||
CodeExample *string `json:"code_example"`
|
||||
IsProtected bool `json:"is_protected"`
|
||||
}
|
||||
|
||||
// BlockTypeCreate represents the request to create a block type.
|
||||
type BlockTypeCreate struct {
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
LogoURL *string `json:"logo_url,omitempty"`
|
||||
DocumentationURL *string `json:"documentation_url,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
CodeExample *string `json:"code_example,omitempty"`
|
||||
}
|
||||
|
||||
// BlockTypeUpdate represents the request to update a block type.
|
||||
type BlockTypeUpdate struct {
|
||||
LogoURL *string `json:"logo_url,omitempty"`
|
||||
DocumentationURL *string `json:"documentation_url,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
CodeExample *string `json:"code_example,omitempty"`
|
||||
}
|
||||
|
||||
// BlockTypeFilter represents filter criteria for querying block types.
|
||||
type BlockTypeFilter struct {
|
||||
Name *string // Filter by name (partial match)
|
||||
Slugs []string // Filter by slugs
|
||||
Capabilities []string // Filter by block capabilities
|
||||
}
|
||||
|
||||
// BlockSchema represents a Prefect block schema.
|
||||
type BlockSchema struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Created *time.Time `json:"created"`
|
||||
Updated *time.Time `json:"updated"`
|
||||
Checksum string `json:"checksum"`
|
||||
Fields map[string]interface{} `json:"fields"`
|
||||
BlockTypeID *uuid.UUID `json:"block_type_id"`
|
||||
BlockType *BlockType `json:"block_type"`
|
||||
Capabilities []string `json:"capabilities"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// BlockSchemaCreate represents the request to create a block schema.
|
||||
type BlockSchemaCreate struct {
|
||||
Fields map[string]interface{} `json:"fields,omitempty"`
|
||||
BlockTypeID uuid.UUID `json:"block_type_id"`
|
||||
Capabilities []string `json:"capabilities,omitempty"`
|
||||
Version *string `json:"version,omitempty"`
|
||||
}
|
||||
|
||||
// BlockSchemaFilter represents filter criteria for querying block schemas.
|
||||
type BlockSchemaFilter struct {
|
||||
BlockTypeID *uuid.UUID // Filter by block type ID
|
||||
Capabilities []string // Filter by required capabilities
|
||||
Version *string // Filter by schema version
|
||||
}
|
||||
|
||||
// BlockDocument represents a Prefect block document.
|
||||
type BlockDocument struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Created *time.Time `json:"created"`
|
||||
Updated *time.Time `json:"updated"`
|
||||
Name *string `json:"name"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
BlockSchemaID uuid.UUID `json:"block_schema_id"`
|
||||
BlockSchema *BlockSchema `json:"block_schema"`
|
||||
BlockTypeID uuid.UUID `json:"block_type_id"`
|
||||
BlockTypeName *string `json:"block_type_name"`
|
||||
BlockType *BlockType `json:"block_type"`
|
||||
BlockDocumentReferences map[string]interface{} `json:"block_document_references"`
|
||||
IsAnonymous bool `json:"is_anonymous"`
|
||||
}
|
||||
|
||||
// BlockDocumentCreate represents the request to create a block document.
|
||||
type BlockDocumentCreate struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
BlockSchemaID uuid.UUID `json:"block_schema_id"`
|
||||
BlockTypeID uuid.UUID `json:"block_type_id"`
|
||||
IsAnonymous bool `json:"is_anonymous,omitempty"`
|
||||
}
|
||||
|
||||
// BlockDocumentUpdate represents the request to update a block document.
|
||||
type BlockDocumentUpdate struct {
|
||||
BlockSchemaID *uuid.UUID `json:"block_schema_id,omitempty"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
MergeExisting *bool `json:"merge_existing_data,omitempty"`
|
||||
}
|
||||
|
||||
// BlockDocumentFilter represents filter criteria for querying block documents.
|
||||
type BlockDocumentFilter struct {
|
||||
BlockTypeID *uuid.UUID // Filter by block type ID
|
||||
Name *string // Filter by document name
|
||||
IsAnonymous *bool // Filter by anonymity
|
||||
}
|
||||
|
||||
// parseOptionalTime parses a time string, returning nil for empty strings.
|
||||
// Supports RFC3339Nano and RFC3339 formats used by the Prefect API.
|
||||
func parseOptionalTime(s string) (*time.Time, error) {
|
||||
if s == "" {
|
||||
return nil, nil
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339Nano, s)
|
||||
if err != nil {
|
||||
t, err = time.Parse(time.RFC3339, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
// optTime handles JSON unmarshaling of nullable time fields that may be
|
||||
// represented as empty strings or JSON null by the Prefect API.
|
||||
type optTime struct {
|
||||
V *time.Time
|
||||
}
|
||||
|
||||
func (o *optTime) UnmarshalJSON(data []byte) error {
|
||||
if string(data) == "null" {
|
||||
o.V = nil
|
||||
return nil
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
v, err := parseOptionalTime(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.V = v
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for time fields.
|
||||
func (f *Flow) UnmarshalJSON(data []byte) error {
|
||||
type Alias Flow
|
||||
aux := &struct {
|
||||
Created string `json:"created"`
|
||||
Updated string `json:"updated"`
|
||||
Created optTime `json:"created"`
|
||||
Updated optTime `json:"updated"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(f),
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
if err := json.Unmarshal(data, aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if aux.Created != "" {
|
||||
t, err := time.Parse(time.RFC3339, aux.Created)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.Created = &t
|
||||
}
|
||||
|
||||
if aux.Updated != "" {
|
||||
t, err := time.Parse(time.RFC3339, aux.Updated)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.Updated = &t
|
||||
}
|
||||
|
||||
f.Created = aux.Created.V
|
||||
f.Updated = aux.Updated.V
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for time fields.
|
||||
func (fr *FlowRun) UnmarshalJSON(data []byte) error {
|
||||
type Alias FlowRun
|
||||
aux := &struct {
|
||||
Created optTime `json:"created"`
|
||||
Updated optTime `json:"updated"`
|
||||
ExpectedStartTime optTime `json:"expected_start_time"`
|
||||
NextScheduledStartTime optTime `json:"next_scheduled_start_time"`
|
||||
StartTime optTime `json:"start_time"`
|
||||
EndTime optTime `json:"end_time"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(fr),
|
||||
}
|
||||
if err := json.Unmarshal(data, aux); err != nil {
|
||||
return err
|
||||
}
|
||||
fr.Created = aux.Created.V
|
||||
fr.Updated = aux.Updated.V
|
||||
fr.ExpectedStartTime = aux.ExpectedStartTime.V
|
||||
fr.NextScheduledStartTime = aux.NextScheduledStartTime.V
|
||||
fr.StartTime = aux.StartTime.V
|
||||
fr.EndTime = aux.EndTime.V
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for time fields.
|
||||
func (d *Deployment) UnmarshalJSON(data []byte) error {
|
||||
type Alias Deployment
|
||||
aux := &struct {
|
||||
Created optTime `json:"created"`
|
||||
Updated optTime `json:"updated"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(d),
|
||||
}
|
||||
if err := json.Unmarshal(data, aux); err != nil {
|
||||
return err
|
||||
}
|
||||
d.Created = aux.Created.V
|
||||
d.Updated = aux.Updated.V
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for time fields.
|
||||
func (tr *TaskRun) UnmarshalJSON(data []byte) error {
|
||||
type Alias TaskRun
|
||||
aux := &struct {
|
||||
Created optTime `json:"created"`
|
||||
Updated optTime `json:"updated"`
|
||||
StartTime optTime `json:"start_time"`
|
||||
EndTime optTime `json:"end_time"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(tr),
|
||||
}
|
||||
if err := json.Unmarshal(data, aux); err != nil {
|
||||
return err
|
||||
}
|
||||
tr.Created = aux.Created.V
|
||||
tr.Updated = aux.Updated.V
|
||||
tr.StartTime = aux.StartTime.V
|
||||
tr.EndTime = aux.EndTime.V
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for time fields.
|
||||
func (wp *WorkPool) UnmarshalJSON(data []byte) error {
|
||||
type Alias WorkPool
|
||||
aux := &struct {
|
||||
Created optTime `json:"created"`
|
||||
Updated optTime `json:"updated"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(wp),
|
||||
}
|
||||
if err := json.Unmarshal(data, aux); err != nil {
|
||||
return err
|
||||
}
|
||||
wp.Created = aux.Created.V
|
||||
wp.Updated = aux.Updated.V
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for time fields.
|
||||
func (wq *WorkQueue) UnmarshalJSON(data []byte) error {
|
||||
type Alias WorkQueue
|
||||
aux := &struct {
|
||||
Created optTime `json:"created"`
|
||||
Updated optTime `json:"updated"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(wq),
|
||||
}
|
||||
if err := json.Unmarshal(data, aux); err != nil {
|
||||
return err
|
||||
}
|
||||
wq.Created = aux.Created.V
|
||||
wq.Updated = aux.Updated.V
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for time fields.
|
||||
func (v *Variable) UnmarshalJSON(data []byte) error {
|
||||
type Alias Variable
|
||||
aux := &struct {
|
||||
Created optTime `json:"created"`
|
||||
Updated optTime `json:"updated"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(v),
|
||||
}
|
||||
if err := json.Unmarshal(data, aux); err != nil {
|
||||
return err
|
||||
}
|
||||
v.Created = aux.Created.V
|
||||
v.Updated = aux.Updated.V
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for time fields.
|
||||
func (l *Log) UnmarshalJSON(data []byte) error {
|
||||
type Alias Log
|
||||
aux := &struct {
|
||||
Created optTime `json:"created"`
|
||||
Updated optTime `json:"updated"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(l),
|
||||
}
|
||||
if err := json.Unmarshal(data, aux); err != nil {
|
||||
return err
|
||||
}
|
||||
l.Created = aux.Created.V
|
||||
l.Updated = aux.Updated.V
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for time fields.
|
||||
func (bt *BlockType) UnmarshalJSON(data []byte) error {
|
||||
type Alias BlockType
|
||||
aux := &struct {
|
||||
Created optTime `json:"created"`
|
||||
Updated optTime `json:"updated"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(bt),
|
||||
}
|
||||
if err := json.Unmarshal(data, aux); err != nil {
|
||||
return err
|
||||
}
|
||||
bt.Created = aux.Created.V
|
||||
bt.Updated = aux.Updated.V
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for time fields.
|
||||
func (bs *BlockSchema) UnmarshalJSON(data []byte) error {
|
||||
type Alias BlockSchema
|
||||
aux := &struct {
|
||||
Created optTime `json:"created"`
|
||||
Updated optTime `json:"updated"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(bs),
|
||||
}
|
||||
if err := json.Unmarshal(data, aux); err != nil {
|
||||
return err
|
||||
}
|
||||
bs.Created = aux.Created.V
|
||||
bs.Updated = aux.Updated.V
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for time fields.
|
||||
func (bd *BlockDocument) UnmarshalJSON(data []byte) error {
|
||||
type Alias BlockDocument
|
||||
aux := &struct {
|
||||
Created optTime `json:"created"`
|
||||
Updated optTime `json:"updated"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(bd),
|
||||
}
|
||||
if err := json.Unmarshal(data, aux); err != nil {
|
||||
return err
|
||||
}
|
||||
bd.Created = aux.Created.V
|
||||
bd.Updated = aux.Updated.V
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,339 +0,0 @@
|
||||
package pagination
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewIterator(t *testing.T) {
|
||||
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
limit int
|
||||
wantLimit int
|
||||
}{
|
||||
{"default limit", 0, 100},
|
||||
{"custom limit", 50, 50},
|
||||
{"negative limit", -10, 100},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
iter := NewIterator(fetchFunc, tt.limit)
|
||||
if iter.limit != tt.wantLimit {
|
||||
t.Errorf("limit = %d, want %d", iter.limit, tt.wantLimit)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIterator_Next(t *testing.T) {
|
||||
items := []string{"item1", "item2", "item3", "item4", "item5"}
|
||||
|
||||
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
|
||||
if offset >= len(items) {
|
||||
return &PaginatedResponse[string]{
|
||||
Results: []string{},
|
||||
Count: len(items),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
HasMore: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
end := offset + limit
|
||||
if end > len(items) {
|
||||
end = len(items)
|
||||
}
|
||||
|
||||
return &PaginatedResponse[string]{
|
||||
Results: items[offset:end],
|
||||
Count: len(items),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
HasMore: end < len(items),
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
iter := NewIterator(fetchFunc, 2)
|
||||
|
||||
var results []string
|
||||
for iter.Next(ctx) {
|
||||
if item := iter.Value(); item != nil {
|
||||
results = append(results, *item)
|
||||
}
|
||||
}
|
||||
|
||||
if err := iter.Err(); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != len(items) {
|
||||
t.Errorf("got %d items, want %d", len(results), len(items))
|
||||
}
|
||||
|
||||
for i, want := range items {
|
||||
if i >= len(results) || results[i] != want {
|
||||
t.Errorf("item[%d] = %v, want %v", i, results[i], want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIterator_Error(t *testing.T) {
|
||||
expectedErr := errors.New("fetch error")
|
||||
|
||||
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
|
||||
return nil, expectedErr
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
iter := NewIterator(fetchFunc, 10)
|
||||
|
||||
if iter.Next(ctx) {
|
||||
t.Error("Next should return false on error")
|
||||
}
|
||||
|
||||
if err := iter.Err(); err != expectedErr {
|
||||
t.Errorf("Err() = %v, want %v", err, expectedErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIterator_EmptyResults(t *testing.T) {
|
||||
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
|
||||
return &PaginatedResponse[string]{
|
||||
Results: []string{},
|
||||
Count: 0,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
HasMore: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
iter := NewIterator(fetchFunc, 10)
|
||||
|
||||
if iter.Next(ctx) {
|
||||
t.Error("Next should return false for empty results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIterator_Reset(t *testing.T) {
|
||||
items := []string{"item1", "item2", "item3"}
|
||||
|
||||
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
|
||||
if offset >= len(items) {
|
||||
return &PaginatedResponse[string]{
|
||||
Results: []string{},
|
||||
HasMore: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &PaginatedResponse[string]{
|
||||
Results: items[offset:],
|
||||
HasMore: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
iter := NewIterator(fetchFunc, 10)
|
||||
|
||||
// First iteration
|
||||
count1 := 0
|
||||
for iter.Next(ctx) {
|
||||
count1++
|
||||
}
|
||||
|
||||
// Reset and iterate again
|
||||
iter.Reset()
|
||||
count2 := 0
|
||||
for iter.Next(ctx) {
|
||||
count2++
|
||||
}
|
||||
|
||||
if count1 != count2 {
|
||||
t.Errorf("counts don't match after reset: %d vs %d", count1, count2)
|
||||
}
|
||||
if count1 != len(items) {
|
||||
t.Errorf("got %d items, want %d", count1, len(items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIterator_Collect(t *testing.T) {
|
||||
items := []string{"item1", "item2", "item3", "item4", "item5"}
|
||||
|
||||
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
|
||||
if offset >= len(items) {
|
||||
return &PaginatedResponse[string]{
|
||||
Results: []string{},
|
||||
HasMore: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
end := offset + limit
|
||||
if end > len(items) {
|
||||
end = len(items)
|
||||
}
|
||||
|
||||
return &PaginatedResponse[string]{
|
||||
Results: items[offset:end],
|
||||
HasMore: end < len(items),
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
iter := NewIterator(fetchFunc, 2)
|
||||
|
||||
results, err := iter.Collect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != len(items) {
|
||||
t.Errorf("got %d items, want %d", len(results), len(items))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIterator_CollectWithLimit(t *testing.T) {
|
||||
items := []string{"item1", "item2", "item3", "item4", "item5"}
|
||||
|
||||
fetchFunc := func(ctx context.Context, offset, limit int) (*PaginatedResponse[string], error) {
|
||||
if offset >= len(items) {
|
||||
return &PaginatedResponse[string]{
|
||||
Results: []string{},
|
||||
HasMore: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
end := offset + limit
|
||||
if end > len(items) {
|
||||
end = len(items)
|
||||
}
|
||||
|
||||
return &PaginatedResponse[string]{
|
||||
Results: items[offset:end],
|
||||
HasMore: end < len(items),
|
||||
}, nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
iter := NewIterator(fetchFunc, 2)
|
||||
|
||||
maxItems := 3
|
||||
results, err := iter.CollectWithLimit(ctx, maxItems)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != maxItems {
|
||||
t.Errorf("got %d items, want %d", len(results), maxItems)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPage_HasNextPage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page Page[string]
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "has next page",
|
||||
page: Page[string]{
|
||||
Items: []string{"a", "b"},
|
||||
Offset: 0,
|
||||
Limit: 2,
|
||||
Total: 5,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no next page",
|
||||
page: Page[string]{
|
||||
Items: []string{"d", "e"},
|
||||
Offset: 3,
|
||||
Limit: 2,
|
||||
Total: 5,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.page.HasNextPage(); got != tt.want {
|
||||
t.Errorf("HasNextPage() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPage_HasPrevPage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page Page[string]
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "has previous page",
|
||||
page: Page[string]{
|
||||
Offset: 2,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no previous page",
|
||||
page: Page[string]{
|
||||
Offset: 0,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.page.HasPrevPage(); got != tt.want {
|
||||
t.Errorf("HasPrevPage() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPage_NextOffset(t *testing.T) {
|
||||
page := Page[string]{
|
||||
Offset: 10,
|
||||
Limit: 5,
|
||||
}
|
||||
|
||||
if got := page.NextOffset(); got != 15 {
|
||||
t.Errorf("NextOffset() = %v, want 15", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPage_PrevOffset(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
offset int
|
||||
limit int
|
||||
want int
|
||||
}{
|
||||
{"normal case", 10, 5, 5},
|
||||
{"at beginning", 3, 5, 0},
|
||||
{"exact boundary", 5, 5, 0},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
page := Page[string]{
|
||||
Offset: tt.offset,
|
||||
Limit: tt.limit,
|
||||
}
|
||||
if got := page.PrevOffset(); got != tt.want {
|
||||
t.Errorf("PrevOffset() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,219 +0,0 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config Config
|
||||
want Config
|
||||
}{
|
||||
{
|
||||
name: "default config",
|
||||
config: Config{},
|
||||
want: DefaultConfig(),
|
||||
},
|
||||
{
|
||||
name: "custom config",
|
||||
config: Config{
|
||||
MaxAttempts: 5,
|
||||
InitialDelay: 2 * time.Second,
|
||||
MaxDelay: 60 * time.Second,
|
||||
BackoffFactor: 3.0,
|
||||
RetryableErrors: []int{500},
|
||||
},
|
||||
want: Config{
|
||||
MaxAttempts: 5,
|
||||
InitialDelay: 2 * time.Second,
|
||||
MaxDelay: 60 * time.Second,
|
||||
BackoffFactor: 3.0,
|
||||
RetryableErrors: []int{500},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := New(tt.config)
|
||||
if r.config.MaxAttempts != tt.want.MaxAttempts {
|
||||
t.Errorf("MaxAttempts = %v, want %v", r.config.MaxAttempts, tt.want.MaxAttempts)
|
||||
}
|
||||
if r.config.InitialDelay != tt.want.InitialDelay {
|
||||
t.Errorf("InitialDelay = %v, want %v", r.config.InitialDelay, tt.want.InitialDelay)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrier_isRetryable(t *testing.T) {
|
||||
r := New(DefaultConfig())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
want bool
|
||||
}{
|
||||
{"429 rate limit", 429, true},
|
||||
{"500 server error", 500, true},
|
||||
{"502 bad gateway", 502, true},
|
||||
{"503 service unavailable", 503, true},
|
||||
{"504 gateway timeout", 504, true},
|
||||
{"200 ok", 200, false},
|
||||
{"400 bad request", 400, false},
|
||||
{"404 not found", 404, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: tt.statusCode}
|
||||
if got := r.isRetryable(resp); got != tt.want {
|
||||
t.Errorf("isRetryable() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrier_Do_Success(t *testing.T) {
|
||||
r := New(DefaultConfig())
|
||||
ctx := context.Background()
|
||||
|
||||
attempts := 0
|
||||
fn := func() (*http.Response, error) {
|
||||
attempts++
|
||||
return &http.Response{StatusCode: 200}, nil
|
||||
}
|
||||
|
||||
resp, err := r.Do(ctx, fn)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("status = %d, want 200", resp.StatusCode)
|
||||
}
|
||||
if attempts != 1 {
|
||||
t.Errorf("attempts = %d, want 1", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrier_Do_Retry(t *testing.T) {
|
||||
config := Config{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
BackoffFactor: 2.0,
|
||||
RetryableErrors: []int{500},
|
||||
}
|
||||
r := New(config)
|
||||
ctx := context.Background()
|
||||
|
||||
attempts := 0
|
||||
fn := func() (*http.Response, error) {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return &http.Response{StatusCode: 500}, nil
|
||||
}
|
||||
return &http.Response{StatusCode: 200}, nil
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := r.Do(ctx, fn)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
t.Errorf("status = %d, want 200", resp.StatusCode)
|
||||
}
|
||||
if attempts != 3 {
|
||||
t.Errorf("attempts = %d, want 3", attempts)
|
||||
}
|
||||
// Should have waited at least once
|
||||
if elapsed < 10*time.Millisecond {
|
||||
t.Errorf("elapsed time %v too short, expected at least 10ms", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrier_Do_ContextCancellation(t *testing.T) {
|
||||
config := Config{
|
||||
MaxAttempts: 5,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 1 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
RetryableErrors: []int{500},
|
||||
}
|
||||
r := New(config)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
attempts := 0
|
||||
fn := func() (*http.Response, error) {
|
||||
attempts++
|
||||
return &http.Response{StatusCode: 500}, nil
|
||||
}
|
||||
|
||||
_, err := r.Do(ctx, fn)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("error = %v, want context.DeadlineExceeded", err)
|
||||
}
|
||||
if attempts > 2 {
|
||||
t.Errorf("attempts = %d, should be cancelled early", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrier_calculateDelay(t *testing.T) {
|
||||
config := Config{
|
||||
InitialDelay: time.Second,
|
||||
MaxDelay: 30 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
}
|
||||
r := New(config)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attempt int
|
||||
resp *http.Response
|
||||
minWant time.Duration
|
||||
maxWant time.Duration
|
||||
}{
|
||||
{
|
||||
name: "first retry",
|
||||
attempt: 1,
|
||||
resp: &http.Response{StatusCode: 500},
|
||||
minWant: 500 * time.Millisecond, // with jitter min 0.5
|
||||
maxWant: 1 * time.Second, // with jitter max 1.0
|
||||
},
|
||||
{
|
||||
name: "second retry",
|
||||
attempt: 2,
|
||||
resp: &http.Response{StatusCode: 500},
|
||||
minWant: 1 * time.Second, // 2^1 * 1s * 0.5
|
||||
maxWant: 2 * time.Second, // 2^1 * 1s * 1.0
|
||||
},
|
||||
{
|
||||
name: "with Retry-After header",
|
||||
attempt: 1,
|
||||
resp: &http.Response{
|
||||
StatusCode: 429,
|
||||
Header: http.Header{"Retry-After": []string{"5"}},
|
||||
},
|
||||
minWant: 5 * time.Second,
|
||||
maxWant: 5 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
delay := r.calculateDelay(tt.attempt, tt.resp)
|
||||
if delay < tt.minWant || delay > tt.maxWant {
|
||||
t.Errorf("delay = %v, want between %v and %v", delay, tt.minWant, tt.maxWant)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user