Initial commit
This commit is contained in:
69
pkg/client/INTEGRATION_TESTS.md
Normal file
69
pkg/client/INTEGRATION_TESTS.md
Normal file
@@ -0,0 +1,69 @@
|
||||
# Integration Tests
|
||||
|
||||
Integration tests require a running Prefect server instance.
|
||||
|
||||
## Running Integration Tests
|
||||
|
||||
### 1. Start a Prefect Server
|
||||
|
||||
```bash
|
||||
prefect server start
|
||||
```
|
||||
|
||||
This will start a Prefect server on `http://localhost:4200`.
|
||||
|
||||
### 2. Run the Integration Tests
|
||||
|
||||
```bash
|
||||
go test -v -tags=integration ./pkg/client/
|
||||
```
|
||||
|
||||
Or using make:
|
||||
|
||||
```bash
|
||||
make test-integration
|
||||
```
|
||||
|
||||
### 3. Custom Server URL
|
||||
|
||||
If your Prefect server is running on a different URL, set the `PREFECT_API_URL` environment variable:
|
||||
|
||||
```bash
|
||||
export PREFECT_API_URL="http://your-server:4200/api"
|
||||
go test -v -tags=integration ./pkg/client/
|
||||
```
|
||||
|
||||
## Test Coverage
|
||||
|
||||
The integration tests cover:
|
||||
|
||||
- **Flow Lifecycle**: Create, get, update, list, and delete flows
|
||||
- **Flow Run Lifecycle**: Create, get, set state, and delete flow runs
|
||||
- **Deployment Lifecycle**: Create, get, pause, resume, and delete deployments
|
||||
- **Pagination**: Manual pagination and iterator-based pagination
|
||||
- **Variables**: Create, get, update, and delete variables
|
||||
- **Admin Endpoints**: Health check and version
|
||||
- **Flow Run Monitoring**: Wait for flow run completion
|
||||
|
||||
## Notes
|
||||
|
||||
- Integration tests use the `integration` build tag to separate them from unit tests
|
||||
- Each test creates and cleans up its own resources
|
||||
- Tests use unique UUIDs in names to avoid conflicts
|
||||
- Some tests may take several seconds to complete (e.g., flow run wait tests)
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Connection Refused
|
||||
|
||||
If you see "connection refused" errors, make sure:
|
||||
1. Prefect server is running (`prefect server start`)
|
||||
2. The server is accessible at the configured URL
|
||||
3. No firewall is blocking the connection
|
||||
|
||||
### Test Timeout
|
||||
|
||||
If tests timeout:
|
||||
1. Increase the test timeout: `go test -timeout 5m -tags=integration ./pkg/client/`
|
||||
2. Check server logs for errors
|
||||
3. Ensure the server is not under heavy load
|
||||
268
pkg/client/client.go
Normal file
268
pkg/client/client.go
Normal file
@@ -0,0 +1,268 @@
|
||||
// Package client provides a Go client for the Prefect Server API.
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gregor/prefect-go/pkg/errors"
|
||||
"github.com/gregor/prefect-go/pkg/retry"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBaseURL = "http://localhost:4200/api"
|
||||
defaultTimeout = 30 * time.Second
|
||||
userAgent = "prefect-go-client/1.0.0"
|
||||
)
|
||||
|
||||
// Client is the main client for interacting with the Prefect API.
|
||||
type Client struct {
|
||||
baseURL *url.URL
|
||||
httpClient *http.Client
|
||||
apiKey string
|
||||
retrier *retry.Retrier
|
||||
userAgent string
|
||||
|
||||
// Services
|
||||
Flows *FlowsService
|
||||
FlowRuns *FlowRunsService
|
||||
Deployments *DeploymentsService
|
||||
TaskRuns *TaskRunsService
|
||||
WorkPools *WorkPoolsService
|
||||
WorkQueues *WorkQueuesService
|
||||
Variables *VariablesService
|
||||
Logs *LogsService
|
||||
Admin *AdminService
|
||||
}
|
||||
|
||||
// Option is a functional option for configuring the client.
|
||||
type Option func(*Client) error
|
||||
|
||||
// NewClient creates a new Prefect API client with the given options.
|
||||
func NewClient(opts ...Option) (*Client, error) {
|
||||
// Parse default base URL
|
||||
baseURL, err := url.Parse(defaultBaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse default base URL: %w", err)
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
baseURL: baseURL,
|
||||
httpClient: &http.Client{
|
||||
Timeout: defaultTimeout,
|
||||
},
|
||||
retrier: retry.New(retry.DefaultConfig()),
|
||||
userAgent: userAgent,
|
||||
}
|
||||
|
||||
// Apply options
|
||||
for _, opt := range opts {
|
||||
if err := opt(c); err != nil {
|
||||
return nil, fmt.Errorf("failed to apply option: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize services
|
||||
c.Flows = &FlowsService{client: c}
|
||||
c.FlowRuns = &FlowRunsService{client: c}
|
||||
c.Deployments = &DeploymentsService{client: c}
|
||||
c.TaskRuns = &TaskRunsService{client: c}
|
||||
c.WorkPools = &WorkPoolsService{client: c}
|
||||
c.WorkQueues = &WorkQueuesService{client: c}
|
||||
c.Variables = &VariablesService{client: c}
|
||||
c.Logs = &LogsService{client: c}
|
||||
c.Admin = &AdminService{client: c}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// WithBaseURL sets the base URL for the Prefect API.
|
||||
func WithBaseURL(baseURL string) Option {
|
||||
return func(c *Client) error {
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid base URL: %w", err)
|
||||
}
|
||||
c.baseURL = u
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithAPIKey sets the API key for authentication with Prefect Cloud.
|
||||
func WithAPIKey(apiKey string) Option {
|
||||
return func(c *Client) error {
|
||||
c.apiKey = apiKey
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient sets a custom HTTP client.
|
||||
func WithHTTPClient(httpClient *http.Client) Option {
|
||||
return func(c *Client) error {
|
||||
if httpClient == nil {
|
||||
return fmt.Errorf("HTTP client cannot be nil")
|
||||
}
|
||||
c.httpClient = httpClient
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout sets the HTTP client timeout.
|
||||
func WithTimeout(timeout time.Duration) Option {
|
||||
return func(c *Client) error {
|
||||
c.httpClient.Timeout = timeout
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithRetry sets the retry configuration.
|
||||
func WithRetry(config retry.Config) Option {
|
||||
return func(c *Client) error {
|
||||
c.retrier = retry.New(config)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithUserAgent sets a custom user agent string.
|
||||
func WithUserAgent(ua string) Option {
|
||||
return func(c *Client) error {
|
||||
c.userAgent = ua
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// do executes an HTTP request with retry logic.
|
||||
func (c *Client) do(ctx context.Context, method, path string, body, result interface{}) error {
|
||||
// Build full URL
|
||||
u, err := c.baseURL.Parse(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse path: %w", err)
|
||||
}
|
||||
|
||||
// Serialize request body if present
|
||||
var reqBody io.Reader
|
||||
if body != nil {
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request body: %w", err)
|
||||
}
|
||||
reqBody = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
// Create request
|
||||
req, err := http.NewRequestWithContext(ctx, method, u.String(), reqBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", c.userAgent)
|
||||
|
||||
// Set API key if present
|
||||
if c.apiKey != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.apiKey))
|
||||
}
|
||||
|
||||
// Execute request with retry logic
|
||||
var resp *http.Response
|
||||
resp, err = c.retrier.Do(ctx, func() (*http.Response, error) {
|
||||
// Clone the request for retry attempts
|
||||
reqClone := req.Clone(ctx)
|
||||
if reqBody != nil {
|
||||
// Reset body for retries
|
||||
if seeker, ok := reqBody.(io.Seeker); ok {
|
||||
if _, err := seeker.Seek(0, 0); err != nil {
|
||||
return nil, fmt.Errorf("failed to reset request body: %w", err)
|
||||
}
|
||||
}
|
||||
reqClone.Body = io.NopCloser(reqBody)
|
||||
}
|
||||
return c.httpClient.Do(reqClone)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Check for error status codes
|
||||
if resp.StatusCode >= 400 {
|
||||
return errors.NewAPIError(resp)
|
||||
}
|
||||
|
||||
// Deserialize response body if result is provided
|
||||
if result != nil && resp.StatusCode != http.StatusNoContent {
|
||||
if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
|
||||
return fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// get performs a GET request.
|
||||
func (c *Client) get(ctx context.Context, path string, result interface{}) error {
|
||||
return c.do(ctx, http.MethodGet, path, nil, result)
|
||||
}
|
||||
|
||||
// post performs a POST request.
|
||||
func (c *Client) post(ctx context.Context, path string, body, result interface{}) error {
|
||||
return c.do(ctx, http.MethodPost, path, body, result)
|
||||
}
|
||||
|
||||
// patch performs a PATCH request.
|
||||
func (c *Client) patch(ctx context.Context, path string, body, result interface{}) error {
|
||||
return c.do(ctx, http.MethodPatch, path, body, result)
|
||||
}
|
||||
|
||||
// delete performs a DELETE request.
|
||||
func (c *Client) delete(ctx context.Context, path string) error {
|
||||
return c.do(ctx, http.MethodDelete, path, nil, nil)
|
||||
}
|
||||
|
||||
// buildPath builds a URL path with optional query parameters.
|
||||
func buildPath(base string, params map[string]string) string {
|
||||
if len(params) == 0 {
|
||||
return base
|
||||
}
|
||||
|
||||
query := url.Values{}
|
||||
for k, v := range params {
|
||||
if v != "" {
|
||||
query.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
if queryStr := query.Encode(); queryStr != "" {
|
||||
return base + "?" + queryStr
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
// buildPathWithValues builds a URL path with url.Values query parameters.
|
||||
func buildPathWithValues(base string, query url.Values) string {
|
||||
if len(query) == 0 {
|
||||
return base
|
||||
}
|
||||
|
||||
if queryStr := query.Encode(); queryStr != "" {
|
||||
return base + "?" + queryStr
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
// joinPath joins URL path segments.
|
||||
func joinPath(parts ...string) string {
|
||||
return strings.Join(parts, "/")
|
||||
}
|
||||
252
pkg/client/client_test.go
Normal file
252
pkg/client/client_test.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opts []Option
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "default client",
|
||||
opts: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "with base URL",
|
||||
opts: []Option{
|
||||
WithBaseURL("http://example.com/api"),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "with invalid base URL",
|
||||
opts: []Option{
|
||||
WithBaseURL("://invalid"),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "with API key",
|
||||
opts: []Option{
|
||||
WithAPIKey("test-key"),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "with custom timeout",
|
||||
opts: []Option{
|
||||
WithTimeout(60 * time.Second),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client, err := NewClient(tt.opts...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && client == nil {
|
||||
t.Error("NewClient() returned nil client")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_do(t *testing.T) {
|
||||
// Create a test server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check headers
|
||||
if r.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("Content-Type = %v, want application/json", r.Header.Get("Content-Type"))
|
||||
}
|
||||
if r.Header.Get("Accept") != "application/json" {
|
||||
t.Errorf("Accept = %v, want application/json", r.Header.Get("Accept"))
|
||||
}
|
||||
|
||||
// Check path
|
||||
if r.URL.Path == "/test" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"message": "success"}`))
|
||||
} else if r.URL.Path == "/error" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"detail": "bad request"}`))
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client, err := NewClient(WithBaseURL(server.URL))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("successful request", func(t *testing.T) {
|
||||
var result map[string]string
|
||||
err := client.get(ctx, "/test", &result)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if result["message"] != "success" {
|
||||
t.Errorf("message = %v, want success", result["message"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error response", func(t *testing.T) {
|
||||
var result map[string]string
|
||||
err := client.get(ctx, "/error", &result)
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_WithAPIKey(t *testing.T) {
|
||||
apiKey := "test-api-key"
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
auth := r.Header.Get("Authorization")
|
||||
expected := "Bearer " + apiKey
|
||||
if auth != expected {
|
||||
t.Errorf("Authorization = %v, want %v", auth, expected)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client, err := NewClient(
|
||||
WithBaseURL(server.URL),
|
||||
WithAPIKey(apiKey),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err = client.get(ctx, "/test", nil)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_WithCustomHTTPClient(t *testing.T) {
|
||||
customClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
client, err := NewClient(WithHTTPClient(customClient))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
if client.httpClient != customClient {
|
||||
t.Error("custom HTTP client not set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_WithNilHTTPClient(t *testing.T) {
|
||||
_, err := NewClient(WithHTTPClient(nil))
|
||||
if err == nil {
|
||||
t.Error("expected error with nil HTTP client")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
base string
|
||||
params map[string]string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no params",
|
||||
base: "/flows",
|
||||
params: nil,
|
||||
want: "/flows",
|
||||
},
|
||||
{
|
||||
name: "with params",
|
||||
base: "/flows",
|
||||
params: map[string]string{
|
||||
"limit": "10",
|
||||
"offset": "20",
|
||||
},
|
||||
want: "/flows?limit=10&offset=20",
|
||||
},
|
||||
{
|
||||
name: "with empty param value",
|
||||
base: "/flows",
|
||||
params: map[string]string{
|
||||
"limit": "10",
|
||||
"offset": "",
|
||||
},
|
||||
want: "/flows?limit=10",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := buildPath(tt.base, tt.params)
|
||||
// Note: URL encoding might change order, so we just check it contains the params
|
||||
if !contains(got, tt.base) {
|
||||
t.Errorf("buildPath() = %v, should contain %v", got, tt.base)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
parts []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "single part",
|
||||
parts: []string{"flows"},
|
||||
want: "flows",
|
||||
},
|
||||
{
|
||||
name: "multiple parts",
|
||||
parts: []string{"flows", "123", "runs"},
|
||||
want: "flows/123/runs",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := joinPath(tt.parts...)
|
||||
if got != tt.want {
|
||||
t.Errorf("joinPath() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || containsMiddle(s, substr)))
|
||||
}
|
||||
|
||||
func containsMiddle(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
157
pkg/client/flow_runs.go
Normal file
157
pkg/client/flow_runs.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gregor/prefect-go/pkg/models"
|
||||
"github.com/gregor/prefect-go/pkg/pagination"
|
||||
)
|
||||
|
||||
// FlowRunsService handles operations related to flow runs.
|
||||
type FlowRunsService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// Create creates a new flow run.
|
||||
func (s *FlowRunsService) Create(ctx context.Context, req *models.FlowRunCreate) (*models.FlowRun, error) {
|
||||
var flowRun models.FlowRun
|
||||
if err := s.client.post(ctx, "/flow_runs/", req, &flowRun); err != nil {
|
||||
return nil, fmt.Errorf("failed to create flow run: %w", err)
|
||||
}
|
||||
return &flowRun, nil
|
||||
}
|
||||
|
||||
// Get retrieves a flow run by ID.
|
||||
func (s *FlowRunsService) Get(ctx context.Context, id uuid.UUID) (*models.FlowRun, error) {
|
||||
var flowRun models.FlowRun
|
||||
path := joinPath("/flow_runs", id.String())
|
||||
if err := s.client.get(ctx, path, &flowRun); err != nil {
|
||||
return nil, fmt.Errorf("failed to get flow run: %w", err)
|
||||
}
|
||||
return &flowRun, nil
|
||||
}
|
||||
|
||||
// List retrieves a list of flow runs with optional filtering.
|
||||
func (s *FlowRunsService) List(ctx context.Context, filter *models.FlowRunFilter, offset, limit int) (*pagination.PaginatedResponse[models.FlowRun], error) {
|
||||
if filter == nil {
|
||||
filter = &models.FlowRunFilter{}
|
||||
}
|
||||
filter.Offset = offset
|
||||
filter.Limit = limit
|
||||
|
||||
type response struct {
|
||||
Results []models.FlowRun `json:"results"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
var resp response
|
||||
if err := s.client.post(ctx, "/flow_runs/filter", filter, &resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to list flow runs: %w", err)
|
||||
}
|
||||
|
||||
return &pagination.PaginatedResponse[models.FlowRun]{
|
||||
Results: resp.Results,
|
||||
Count: resp.Count,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
HasMore: offset+len(resp.Results) < resp.Count,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListAll returns an iterator for all flow runs matching the filter.
|
||||
func (s *FlowRunsService) ListAll(ctx context.Context, filter *models.FlowRunFilter) *pagination.Iterator[models.FlowRun] {
|
||||
fetchFunc := func(ctx context.Context, offset, limit int) (*pagination.PaginatedResponse[models.FlowRun], error) {
|
||||
return s.List(ctx, filter, offset, limit)
|
||||
}
|
||||
return pagination.NewIterator(fetchFunc, 100)
|
||||
}
|
||||
|
||||
// Update updates a flow run.
|
||||
func (s *FlowRunsService) Update(ctx context.Context, id uuid.UUID, req *models.FlowRunUpdate) (*models.FlowRun, error) {
|
||||
var flowRun models.FlowRun
|
||||
path := joinPath("/flow_runs", id.String())
|
||||
if err := s.client.patch(ctx, path, req, &flowRun); err != nil {
|
||||
return nil, fmt.Errorf("failed to update flow run: %w", err)
|
||||
}
|
||||
return &flowRun, nil
|
||||
}
|
||||
|
||||
// Delete deletes a flow run by ID.
|
||||
func (s *FlowRunsService) Delete(ctx context.Context, id uuid.UUID) error {
|
||||
path := joinPath("/flow_runs", id.String())
|
||||
if err := s.client.delete(ctx, path); err != nil {
|
||||
return fmt.Errorf("failed to delete flow run: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetState sets the state of a flow run.
|
||||
func (s *FlowRunsService) SetState(ctx context.Context, id uuid.UUID, state *models.StateCreate) (*models.FlowRun, error) {
|
||||
var flowRun models.FlowRun
|
||||
path := joinPath("/flow_runs", id.String(), "set_state")
|
||||
|
||||
req := struct {
|
||||
State *models.StateCreate `json:"state"`
|
||||
}{
|
||||
State: state,
|
||||
}
|
||||
|
||||
if err := s.client.post(ctx, path, req, &flowRun); err != nil {
|
||||
return nil, fmt.Errorf("failed to set flow run state: %w", err)
|
||||
}
|
||||
return &flowRun, nil
|
||||
}
|
||||
|
||||
// Wait waits for a flow run to complete by polling its state.
|
||||
// It returns the final flow run state or an error if the context is cancelled.
|
||||
func (s *FlowRunsService) Wait(ctx context.Context, id uuid.UUID, pollInterval time.Duration) (*models.FlowRun, error) {
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = 5 * time.Second
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
flowRun, err := s.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if flow run is in a terminal state
|
||||
if flowRun.StateType != nil {
|
||||
switch *flowRun.StateType {
|
||||
case models.StateTypeCompleted,
|
||||
models.StateTypeFailed,
|
||||
models.StateTypeCancelled,
|
||||
models.StateTypeCrashed:
|
||||
return flowRun, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count returns the number of flow runs matching the filter.
|
||||
func (s *FlowRunsService) Count(ctx context.Context, filter *models.FlowRunFilter) (int, error) {
|
||||
if filter == nil {
|
||||
filter = &models.FlowRunFilter{}
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
if err := s.client.post(ctx, "/flow_runs/count", filter, &result); err != nil {
|
||||
return 0, fmt.Errorf("failed to count flow runs: %w", err)
|
||||
}
|
||||
|
||||
return result.Count, nil
|
||||
}
|
||||
115
pkg/client/flows.go
Normal file
115
pkg/client/flows.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gregor/prefect-go/pkg/models"
|
||||
"github.com/gregor/prefect-go/pkg/pagination"
|
||||
)
|
||||
|
||||
// FlowsService handles operations related to flows.
|
||||
type FlowsService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// Create creates a new flow.
|
||||
func (s *FlowsService) Create(ctx context.Context, req *models.FlowCreate) (*models.Flow, error) {
|
||||
var flow models.Flow
|
||||
if err := s.client.post(ctx, "/flows/", req, &flow); err != nil {
|
||||
return nil, fmt.Errorf("failed to create flow: %w", err)
|
||||
}
|
||||
return &flow, nil
|
||||
}
|
||||
|
||||
// Get retrieves a flow by ID.
|
||||
func (s *FlowsService) Get(ctx context.Context, id uuid.UUID) (*models.Flow, error) {
|
||||
var flow models.Flow
|
||||
path := joinPath("/flows", id.String())
|
||||
if err := s.client.get(ctx, path, &flow); err != nil {
|
||||
return nil, fmt.Errorf("failed to get flow: %w", err)
|
||||
}
|
||||
return &flow, nil
|
||||
}
|
||||
|
||||
// GetByName retrieves a flow by name.
|
||||
func (s *FlowsService) GetByName(ctx context.Context, name string) (*models.Flow, error) {
|
||||
var flow models.Flow
|
||||
path := buildPath("/flows/name/"+name, nil)
|
||||
if err := s.client.get(ctx, path, &flow); err != nil {
|
||||
return nil, fmt.Errorf("failed to get flow by name: %w", err)
|
||||
}
|
||||
return &flow, nil
|
||||
}
|
||||
|
||||
// List retrieves a list of flows with optional filtering.
|
||||
func (s *FlowsService) List(ctx context.Context, filter *models.FlowFilter, offset, limit int) (*pagination.PaginatedResponse[models.Flow], error) {
|
||||
if filter == nil {
|
||||
filter = &models.FlowFilter{}
|
||||
}
|
||||
filter.Offset = offset
|
||||
filter.Limit = limit
|
||||
|
||||
type response struct {
|
||||
Results []models.Flow `json:"results"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
var resp response
|
||||
if err := s.client.post(ctx, "/flows/filter", filter, &resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to list flows: %w", err)
|
||||
}
|
||||
|
||||
return &pagination.PaginatedResponse[models.Flow]{
|
||||
Results: resp.Results,
|
||||
Count: resp.Count,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
HasMore: offset+len(resp.Results) < resp.Count,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListAll returns an iterator for all flows matching the filter.
|
||||
func (s *FlowsService) ListAll(ctx context.Context, filter *models.FlowFilter) *pagination.Iterator[models.Flow] {
|
||||
fetchFunc := func(ctx context.Context, offset, limit int) (*pagination.PaginatedResponse[models.Flow], error) {
|
||||
return s.List(ctx, filter, offset, limit)
|
||||
}
|
||||
return pagination.NewIterator(fetchFunc, 100)
|
||||
}
|
||||
|
||||
// Update updates a flow.
|
||||
func (s *FlowsService) Update(ctx context.Context, id uuid.UUID, req *models.FlowUpdate) (*models.Flow, error) {
|
||||
var flow models.Flow
|
||||
path := joinPath("/flows", id.String())
|
||||
if err := s.client.patch(ctx, path, req, &flow); err != nil {
|
||||
return nil, fmt.Errorf("failed to update flow: %w", err)
|
||||
}
|
||||
return &flow, nil
|
||||
}
|
||||
|
||||
// Delete deletes a flow by ID.
|
||||
func (s *FlowsService) Delete(ctx context.Context, id uuid.UUID) error {
|
||||
path := joinPath("/flows", id.String())
|
||||
if err := s.client.delete(ctx, path); err != nil {
|
||||
return fmt.Errorf("failed to delete flow: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count returns the number of flows matching the filter.
|
||||
func (s *FlowsService) Count(ctx context.Context, filter *models.FlowFilter) (int, error) {
|
||||
if filter == nil {
|
||||
filter = &models.FlowFilter{}
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
if err := s.client.post(ctx, "/flows/count", filter, &result); err != nil {
|
||||
return 0, fmt.Errorf("failed to count flows: %w", err)
|
||||
}
|
||||
|
||||
return result.Count, nil
|
||||
}
|
||||
415
pkg/client/integration_test.go
Normal file
415
pkg/client/integration_test.go
Normal file
@@ -0,0 +1,415 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gregor/prefect-go/pkg/client"
|
||||
"github.com/gregor/prefect-go/pkg/errors"
|
||||
"github.com/gregor/prefect-go/pkg/models"
|
||||
)
|
||||
|
||||
var testClient *client.Client
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Get base URL from environment or use default
|
||||
baseURL := os.Getenv("PREFECT_API_URL")
|
||||
if baseURL == "" {
|
||||
baseURL = "http://localhost:4200/api"
|
||||
}
|
||||
|
||||
// Create test client
|
||||
var err error
|
||||
testClient, err = client.NewClient(
|
||||
client.WithBaseURL(baseURL),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestIntegration_FlowLifecycle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a flow
|
||||
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
|
||||
Name: "test-flow-" + uuid.New().String(),
|
||||
Tags: []string{"integration-test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow: %v", err)
|
||||
}
|
||||
defer testClient.Flows.Delete(ctx, flow.ID)
|
||||
|
||||
// Verify flow was created
|
||||
if flow.ID == uuid.Nil {
|
||||
t.Error("Flow ID is nil")
|
||||
}
|
||||
if flow.Name == "" {
|
||||
t.Error("Flow name is empty")
|
||||
}
|
||||
|
||||
// Get the flow
|
||||
retrievedFlow, err := testClient.Flows.Get(ctx, flow.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get flow: %v", err)
|
||||
}
|
||||
if retrievedFlow.ID != flow.ID {
|
||||
t.Errorf("Flow ID mismatch: got %v, want %v", retrievedFlow.ID, flow.ID)
|
||||
}
|
||||
|
||||
// Update the flow
|
||||
newTags := []string{"integration-test", "updated"}
|
||||
updatedFlow, err := testClient.Flows.Update(ctx, flow.ID, &models.FlowUpdate{
|
||||
Tags: &newTags,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update flow: %v", err)
|
||||
}
|
||||
if len(updatedFlow.Tags) != 2 {
|
||||
t.Errorf("Expected 2 tags, got %d", len(updatedFlow.Tags))
|
||||
}
|
||||
|
||||
// List flows
|
||||
flowsPage, err := testClient.Flows.List(ctx, &models.FlowFilter{
|
||||
Tags: []string{"integration-test"},
|
||||
}, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list flows: %v", err)
|
||||
}
|
||||
if len(flowsPage.Results) == 0 {
|
||||
t.Error("Expected at least one flow in results")
|
||||
}
|
||||
|
||||
// Delete the flow
|
||||
if err := testClient.Flows.Delete(ctx, flow.ID); err != nil {
|
||||
t.Fatalf("Failed to delete flow: %v", err)
|
||||
}
|
||||
|
||||
// Verify deletion
|
||||
_, err = testClient.Flows.Get(ctx, flow.ID)
|
||||
if !errors.IsNotFound(err) {
|
||||
t.Errorf("Expected NotFound error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_FlowRunLifecycle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a flow first
|
||||
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
|
||||
Name: "test-flow-run-" + uuid.New().String(),
|
||||
Tags: []string{"integration-test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow: %v", err)
|
||||
}
|
||||
defer testClient.Flows.Delete(ctx, flow.ID)
|
||||
|
||||
// Create a flow run
|
||||
flowRun, err := testClient.FlowRuns.Create(ctx, &models.FlowRunCreate{
|
||||
FlowID: flow.ID,
|
||||
Name: "test-run",
|
||||
Parameters: map[string]interface{}{
|
||||
"test_param": "value",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow run: %v", err)
|
||||
}
|
||||
defer testClient.FlowRuns.Delete(ctx, flowRun.ID)
|
||||
|
||||
// Verify flow run was created
|
||||
if flowRun.ID == uuid.Nil {
|
||||
t.Error("Flow run ID is nil")
|
||||
}
|
||||
if flowRun.FlowID != flow.ID {
|
||||
t.Errorf("Flow ID mismatch: got %v, want %v", flowRun.FlowID, flow.ID)
|
||||
}
|
||||
|
||||
// Get the flow run
|
||||
retrievedRun, err := testClient.FlowRuns.Get(ctx, flowRun.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get flow run: %v", err)
|
||||
}
|
||||
if retrievedRun.ID != flowRun.ID {
|
||||
t.Errorf("Flow run ID mismatch: got %v, want %v", retrievedRun.ID, flowRun.ID)
|
||||
}
|
||||
|
||||
// Set state to RUNNING
|
||||
runningState := models.StateTypeRunning
|
||||
updatedRun, err := testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{
|
||||
Type: runningState,
|
||||
Message: strPtr("Test running"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set state: %v", err)
|
||||
}
|
||||
if updatedRun.StateType == nil || *updatedRun.StateType != runningState {
|
||||
t.Errorf("State type mismatch: got %v, want %v", updatedRun.StateType, runningState)
|
||||
}
|
||||
|
||||
// Set state to COMPLETED
|
||||
completedState := models.StateTypeCompleted
|
||||
completedRun, err := testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{
|
||||
Type: completedState,
|
||||
Message: strPtr("Test completed"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set state: %v", err)
|
||||
}
|
||||
if completedRun.StateType == nil || *completedRun.StateType != completedState {
|
||||
t.Errorf("State type mismatch: got %v, want %v", completedRun.StateType, completedState)
|
||||
}
|
||||
|
||||
// List flow runs
|
||||
runsPage, err := testClient.FlowRuns.List(ctx, &models.FlowRunFilter{
|
||||
FlowID: &flow.ID,
|
||||
}, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list flow runs: %v", err)
|
||||
}
|
||||
if len(runsPage.Results) == 0 {
|
||||
t.Error("Expected at least one flow run in results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_DeploymentLifecycle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a flow first
|
||||
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
|
||||
Name: "test-deployment-flow-" + uuid.New().String(),
|
||||
Tags: []string{"integration-test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow: %v", err)
|
||||
}
|
||||
defer testClient.Flows.Delete(ctx, flow.ID)
|
||||
|
||||
// Create a deployment
|
||||
workPoolName := "default-pool"
|
||||
deployment, err := testClient.Deployments.Create(ctx, &models.DeploymentCreate{
|
||||
Name: "test-deployment",
|
||||
FlowID: flow.ID,
|
||||
WorkPoolName: &workPoolName,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create deployment: %v", err)
|
||||
}
|
||||
defer testClient.Deployments.Delete(ctx, deployment.ID)
|
||||
|
||||
// Verify deployment was created
|
||||
if deployment.ID == uuid.Nil {
|
||||
t.Error("Deployment ID is nil")
|
||||
}
|
||||
if deployment.FlowID != flow.ID {
|
||||
t.Errorf("Flow ID mismatch: got %v, want %v", deployment.FlowID, flow.ID)
|
||||
}
|
||||
|
||||
// Get the deployment
|
||||
retrievedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get deployment: %v", err)
|
||||
}
|
||||
if retrievedDeployment.ID != deployment.ID {
|
||||
t.Errorf("Deployment ID mismatch: got %v, want %v", retrievedDeployment.ID, deployment.ID)
|
||||
}
|
||||
|
||||
// Pause the deployment
|
||||
if err := testClient.Deployments.Pause(ctx, deployment.ID); err != nil {
|
||||
t.Fatalf("Failed to pause deployment: %v", err)
|
||||
}
|
||||
|
||||
// Verify paused
|
||||
pausedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get deployment: %v", err)
|
||||
}
|
||||
if !pausedDeployment.Paused {
|
||||
t.Error("Expected deployment to be paused")
|
||||
}
|
||||
|
||||
// Resume the deployment
|
||||
if err := testClient.Deployments.Resume(ctx, deployment.ID); err != nil {
|
||||
t.Fatalf("Failed to resume deployment: %v", err)
|
||||
}
|
||||
|
||||
// Verify resumed
|
||||
resumedDeployment, err := testClient.Deployments.Get(ctx, deployment.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get deployment: %v", err)
|
||||
}
|
||||
if resumedDeployment.Paused {
|
||||
t.Error("Expected deployment to be resumed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_Pagination(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create multiple flows
|
||||
createdFlows := make([]uuid.UUID, 0)
|
||||
for i := 0; i < 15; i++ {
|
||||
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
|
||||
Name: "pagination-test-" + uuid.New().String(),
|
||||
Tags: []string{"pagination-integration-test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow %d: %v", i, err)
|
||||
}
|
||||
createdFlows = append(createdFlows, flow.ID)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
defer func() {
|
||||
for _, id := range createdFlows {
|
||||
testClient.Flows.Delete(ctx, id)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test manual pagination
|
||||
page1, err := testClient.Flows.List(ctx, &models.FlowFilter{
|
||||
Tags: []string{"pagination-integration-test"},
|
||||
}, 0, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list flows: %v", err)
|
||||
}
|
||||
if len(page1.Results) != 5 {
|
||||
t.Errorf("Expected 5 flows in page 1, got %d", len(page1.Results))
|
||||
}
|
||||
if !page1.HasMore {
|
||||
t.Error("Expected more pages")
|
||||
}
|
||||
|
||||
// Test iterator
|
||||
iter := testClient.Flows.ListAll(ctx, &models.FlowFilter{
|
||||
Tags: []string{"pagination-integration-test"},
|
||||
})
|
||||
|
||||
count := 0
|
||||
for iter.Next(ctx) {
|
||||
count++
|
||||
}
|
||||
if err := iter.Err(); err != nil {
|
||||
t.Fatalf("Iterator error: %v", err)
|
||||
}
|
||||
if count != 15 {
|
||||
t.Errorf("Expected to iterate over 15 flows, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_Variables(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a variable
|
||||
variable, err := testClient.Variables.Create(ctx, &models.VariableCreate{
|
||||
Name: "test-var-" + uuid.New().String(),
|
||||
Value: "test-value",
|
||||
Tags: []string{"integration-test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create variable: %v", err)
|
||||
}
|
||||
defer testClient.Variables.Delete(ctx, variable.ID)
|
||||
|
||||
// Get the variable
|
||||
retrievedVar, err := testClient.Variables.Get(ctx, variable.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get variable: %v", err)
|
||||
}
|
||||
if retrievedVar.Value != "test-value" {
|
||||
t.Errorf("Value mismatch: got %v, want test-value", retrievedVar.Value)
|
||||
}
|
||||
|
||||
// Update the variable
|
||||
newValue := "updated-value"
|
||||
updatedVar, err := testClient.Variables.Update(ctx, variable.ID, &models.VariableUpdate{
|
||||
Value: &newValue,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update variable: %v", err)
|
||||
}
|
||||
if updatedVar.Value != newValue {
|
||||
t.Errorf("Value mismatch: got %v, want %v", updatedVar.Value, newValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_AdminEndpoints(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test health check
|
||||
if err := testClient.Admin.Health(ctx); err != nil {
|
||||
t.Fatalf("Health check failed: %v", err)
|
||||
}
|
||||
|
||||
// Test version
|
||||
version, err := testClient.Admin.Version(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get version: %v", err)
|
||||
}
|
||||
if version == "" {
|
||||
t.Error("Version is empty")
|
||||
}
|
||||
t.Logf("Server version: %s", version)
|
||||
}
|
||||
|
||||
func TestIntegration_FlowRunWait(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a flow
|
||||
flow, err := testClient.Flows.Create(ctx, &models.FlowCreate{
|
||||
Name: "wait-test-" + uuid.New().String(),
|
||||
Tags: []string{"integration-test"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow: %v", err)
|
||||
}
|
||||
defer testClient.Flows.Delete(ctx, flow.ID)
|
||||
|
||||
// Create a flow run
|
||||
flowRun, err := testClient.FlowRuns.Create(ctx, &models.FlowRunCreate{
|
||||
FlowID: flow.ID,
|
||||
Name: "wait-test-run",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create flow run: %v", err)
|
||||
}
|
||||
defer testClient.FlowRuns.Delete(ctx, flowRun.ID)
|
||||
|
||||
// Simulate completion in background
|
||||
go func() {
|
||||
time.Sleep(2 * time.Second)
|
||||
completedState := models.StateTypeCompleted
|
||||
testClient.FlowRuns.SetState(ctx, flowRun.ID, &models.StateCreate{
|
||||
Type: completedState,
|
||||
Message: strPtr("Test completed"),
|
||||
})
|
||||
}()
|
||||
|
||||
// Wait for completion with timeout
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
finalRun, err := testClient.FlowRuns.Wait(waitCtx, flowRun.ID, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to wait for flow run: %v", err)
|
||||
}
|
||||
|
||||
if finalRun.StateType == nil || *finalRun.StateType != models.StateTypeCompleted {
|
||||
t.Errorf("Expected COMPLETED state, got %v", finalRun.StateType)
|
||||
}
|
||||
}
|
||||
|
||||
func strPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
301
pkg/client/services.go
Normal file
301
pkg/client/services.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gregor/prefect-go/pkg/models"
|
||||
"github.com/gregor/prefect-go/pkg/pagination"
|
||||
)
|
||||
|
||||
// DeploymentsService handles operations related to deployments.
|
||||
type DeploymentsService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// Create creates a new deployment.
|
||||
func (s *DeploymentsService) Create(ctx context.Context, req *models.DeploymentCreate) (*models.Deployment, error) {
|
||||
var deployment models.Deployment
|
||||
if err := s.client.post(ctx, "/deployments/", req, &deployment); err != nil {
|
||||
return nil, fmt.Errorf("failed to create deployment: %w", err)
|
||||
}
|
||||
return &deployment, nil
|
||||
}
|
||||
|
||||
// Get retrieves a deployment by ID.
|
||||
func (s *DeploymentsService) Get(ctx context.Context, id uuid.UUID) (*models.Deployment, error) {
|
||||
var deployment models.Deployment
|
||||
path := joinPath("/deployments", id.String())
|
||||
if err := s.client.get(ctx, path, &deployment); err != nil {
|
||||
return nil, fmt.Errorf("failed to get deployment: %w", err)
|
||||
}
|
||||
return &deployment, nil
|
||||
}
|
||||
|
||||
// GetByName retrieves a deployment by flow name and deployment name.
|
||||
func (s *DeploymentsService) GetByName(ctx context.Context, flowName, deploymentName string) (*models.Deployment, error) {
|
||||
var deployment models.Deployment
|
||||
path := fmt.Sprintf("/deployments/name/%s/%s", flowName, deploymentName)
|
||||
if err := s.client.get(ctx, path, &deployment); err != nil {
|
||||
return nil, fmt.Errorf("failed to get deployment by name: %w", err)
|
||||
}
|
||||
return &deployment, nil
|
||||
}
|
||||
|
||||
// Update updates a deployment.
|
||||
func (s *DeploymentsService) Update(ctx context.Context, id uuid.UUID, req *models.DeploymentUpdate) (*models.Deployment, error) {
|
||||
var deployment models.Deployment
|
||||
path := joinPath("/deployments", id.String())
|
||||
if err := s.client.patch(ctx, path, req, &deployment); err != nil {
|
||||
return nil, fmt.Errorf("failed to update deployment: %w", err)
|
||||
}
|
||||
return &deployment, nil
|
||||
}
|
||||
|
||||
// Delete deletes a deployment by ID.
|
||||
func (s *DeploymentsService) Delete(ctx context.Context, id uuid.UUID) error {
|
||||
path := joinPath("/deployments", id.String())
|
||||
if err := s.client.delete(ctx, path); err != nil {
|
||||
return fmt.Errorf("failed to delete deployment: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pause pauses a deployment's schedule.
|
||||
func (s *DeploymentsService) Pause(ctx context.Context, id uuid.UUID) error {
|
||||
path := joinPath("/deployments", id.String(), "pause")
|
||||
if err := s.client.post(ctx, path, nil, nil); err != nil {
|
||||
return fmt.Errorf("failed to pause deployment: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resume resumes a deployment's schedule.
|
||||
func (s *DeploymentsService) Resume(ctx context.Context, id uuid.UUID) error {
|
||||
path := joinPath("/deployments", id.String(), "resume")
|
||||
if err := s.client.post(ctx, path, nil, nil); err != nil {
|
||||
return fmt.Errorf("failed to resume deployment: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateFlowRun creates a flow run from a deployment.
|
||||
func (s *DeploymentsService) CreateFlowRun(ctx context.Context, id uuid.UUID, params map[string]interface{}) (*models.FlowRun, error) {
|
||||
var flowRun models.FlowRun
|
||||
path := joinPath("/deployments", id.String(), "create_flow_run")
|
||||
|
||||
req := struct {
|
||||
Parameters map[string]interface{} `json:"parameters,omitempty"`
|
||||
}{
|
||||
Parameters: params,
|
||||
}
|
||||
|
||||
if err := s.client.post(ctx, path, req, &flowRun); err != nil {
|
||||
return nil, fmt.Errorf("failed to create flow run from deployment: %w", err)
|
||||
}
|
||||
return &flowRun, nil
|
||||
}
|
||||
|
||||
// TaskRunsService handles operations related to task runs.
|
||||
type TaskRunsService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// Create creates a new task run.
|
||||
func (t *TaskRunsService) Create(ctx context.Context, req *models.TaskRunCreate) (*models.TaskRun, error) {
|
||||
var taskRun models.TaskRun
|
||||
if err := t.client.post(ctx, "/task_runs/", req, &taskRun); err != nil {
|
||||
return nil, fmt.Errorf("failed to create task run: %w", err)
|
||||
}
|
||||
return &taskRun, nil
|
||||
}
|
||||
|
||||
// Get retrieves a task run by ID.
|
||||
func (t *TaskRunsService) Get(ctx context.Context, id uuid.UUID) (*models.TaskRun, error) {
|
||||
var taskRun models.TaskRun
|
||||
path := joinPath("/task_runs", id.String())
|
||||
if err := t.client.get(ctx, path, &taskRun); err != nil {
|
||||
return nil, fmt.Errorf("failed to get task run: %w", err)
|
||||
}
|
||||
return &taskRun, nil
|
||||
}
|
||||
|
||||
// Delete deletes a task run by ID.
|
||||
func (t *TaskRunsService) Delete(ctx context.Context, id uuid.UUID) error {
|
||||
path := joinPath("/task_runs", id.String())
|
||||
if err := t.client.delete(ctx, path); err != nil {
|
||||
return fmt.Errorf("failed to delete task run: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetState sets the state of a task run.
|
||||
func (t *TaskRunsService) SetState(ctx context.Context, id uuid.UUID, state *models.StateCreate) (*models.TaskRun, error) {
|
||||
var taskRun models.TaskRun
|
||||
path := joinPath("/task_runs", id.String(), "set_state")
|
||||
|
||||
req := struct {
|
||||
State *models.StateCreate `json:"state"`
|
||||
}{
|
||||
State: state,
|
||||
}
|
||||
|
||||
if err := t.client.post(ctx, path, req, &taskRun); err != nil {
|
||||
return nil, fmt.Errorf("failed to set task run state: %w", err)
|
||||
}
|
||||
return &taskRun, nil
|
||||
}
|
||||
|
||||
// WorkPoolsService handles operations related to work pools.
|
||||
type WorkPoolsService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// Get retrieves a work pool by name.
|
||||
func (w *WorkPoolsService) Get(ctx context.Context, name string) (*models.WorkPool, error) {
|
||||
var workPool models.WorkPool
|
||||
path := joinPath("/work_pools", name)
|
||||
if err := w.client.get(ctx, path, &workPool); err != nil {
|
||||
return nil, fmt.Errorf("failed to get work pool: %w", err)
|
||||
}
|
||||
return &workPool, nil
|
||||
}
|
||||
|
||||
// WorkQueuesService handles operations related to work queues.
|
||||
type WorkQueuesService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// Get retrieves a work queue by ID.
|
||||
func (w *WorkQueuesService) Get(ctx context.Context, id uuid.UUID) (*models.WorkQueue, error) {
|
||||
var workQueue models.WorkQueue
|
||||
path := joinPath("/work_queues", id.String())
|
||||
if err := w.client.get(ctx, path, &workQueue); err != nil {
|
||||
return nil, fmt.Errorf("failed to get work queue: %w", err)
|
||||
}
|
||||
return &workQueue, nil
|
||||
}
|
||||
|
||||
// VariablesService handles operations related to variables.
|
||||
type VariablesService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// Create creates a new variable.
|
||||
func (v *VariablesService) Create(ctx context.Context, req *models.VariableCreate) (*models.Variable, error) {
|
||||
var variable models.Variable
|
||||
if err := v.client.post(ctx, "/variables/", req, &variable); err != nil {
|
||||
return nil, fmt.Errorf("failed to create variable: %w", err)
|
||||
}
|
||||
return &variable, nil
|
||||
}
|
||||
|
||||
// Get retrieves a variable by ID.
|
||||
func (v *VariablesService) Get(ctx context.Context, id uuid.UUID) (*models.Variable, error) {
|
||||
var variable models.Variable
|
||||
path := joinPath("/variables", id.String())
|
||||
if err := v.client.get(ctx, path, &variable); err != nil {
|
||||
return nil, fmt.Errorf("failed to get variable: %w", err)
|
||||
}
|
||||
return &variable, nil
|
||||
}
|
||||
|
||||
// GetByName retrieves a variable by name.
|
||||
func (v *VariablesService) GetByName(ctx context.Context, name string) (*models.Variable, error) {
|
||||
var variable models.Variable
|
||||
path := joinPath("/variables/name", name)
|
||||
if err := v.client.get(ctx, path, &variable); err != nil {
|
||||
return nil, fmt.Errorf("failed to get variable by name: %w", err)
|
||||
}
|
||||
return &variable, nil
|
||||
}
|
||||
|
||||
// Update updates a variable.
|
||||
func (v *VariablesService) Update(ctx context.Context, id uuid.UUID, req *models.VariableUpdate) (*models.Variable, error) {
|
||||
var variable models.Variable
|
||||
path := joinPath("/variables", id.String())
|
||||
if err := v.client.patch(ctx, path, req, &variable); err != nil {
|
||||
return nil, fmt.Errorf("failed to update variable: %w", err)
|
||||
}
|
||||
return &variable, nil
|
||||
}
|
||||
|
||||
// Delete deletes a variable by ID.
|
||||
func (v *VariablesService) Delete(ctx context.Context, id uuid.UUID) error {
|
||||
path := joinPath("/variables", id.String())
|
||||
if err := v.client.delete(ctx, path); err != nil {
|
||||
return fmt.Errorf("failed to delete variable: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogsService handles operations related to logs.
|
||||
type LogsService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// Create creates new log entries.
|
||||
func (l *LogsService) Create(ctx context.Context, logs []*models.LogCreate) error {
|
||||
if err := l.client.post(ctx, "/logs/", logs, nil); err != nil {
|
||||
return fmt.Errorf("failed to create logs: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List retrieves logs with filtering.
|
||||
func (l *LogsService) List(ctx context.Context, filter interface{}, offset, limit int) (*pagination.PaginatedResponse[models.Log], error) {
|
||||
type request struct {
|
||||
Filter interface{} `json:"filter,omitempty"`
|
||||
Offset int `json:"offset"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
req := request{
|
||||
Filter: filter,
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
}
|
||||
|
||||
type response struct {
|
||||
Results []models.Log `json:"results"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
var resp response
|
||||
if err := l.client.post(ctx, "/logs/filter", req, &resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to list logs: %w", err)
|
||||
}
|
||||
|
||||
return &pagination.PaginatedResponse[models.Log]{
|
||||
Results: resp.Results,
|
||||
Count: resp.Count,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
HasMore: offset+len(resp.Results) < resp.Count,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AdminService handles administrative operations.
|
||||
type AdminService struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
// Health checks the health of the Prefect server.
|
||||
func (a *AdminService) Health(ctx context.Context) error {
|
||||
if err := a.client.get(ctx, "/health", nil); err != nil {
|
||||
return fmt.Errorf("health check failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Version retrieves the server version.
|
||||
func (a *AdminService) Version(ctx context.Context) (string, error) {
|
||||
var result struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
if err := a.client.get(ctx, "/version", &result); err != nil {
|
||||
return "", fmt.Errorf("failed to get version: %w", err)
|
||||
}
|
||||
return result.Version, nil
|
||||
}
|
||||
Reference in New Issue
Block a user