From 9673251f9ea740ea305f5f57d7061f4c9a957bd0 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Thu, 11 Jan 2024 20:48:41 +0800 Subject: [PATCH 01/12] v1 Signed-off-by: Future Outlier --- .../go/tasks/pluginmachinery/k8s/client.go | 2 +- .../go/tasks/plugins/webapi/agent/client.go | 43 +++++++++++++++++++ .../tasks/plugins/webapi/agent/client_test.go | 14 ++++++ .../plugins/webapi/agent/integration_test.go | 12 ++++-- .../go/tasks/plugins/webapi/agent/plugin.go | 39 ++++------------- .../tasks/plugins/webapi/agent/plugin_test.go | 16 ++++--- 6 files changed, 87 insertions(+), 39 deletions(-) create mode 100644 flyteplugins/go/tasks/plugins/webapi/agent/client.go create mode 100644 flyteplugins/go/tasks/plugins/webapi/agent/client_test.go diff --git a/flyteplugins/go/tasks/pluginmachinery/k8s/client.go b/flyteplugins/go/tasks/pluginmachinery/k8s/client.go index f14ae2c8a0..0ab46081e9 100644 --- a/flyteplugins/go/tasks/pluginmachinery/k8s/client.go +++ b/flyteplugins/go/tasks/pluginmachinery/k8s/client.go @@ -69,7 +69,7 @@ func NewKubeClient(config *rest.Config, options Options) (core.KubeClient, error if options.ClientOptions == nil { options.ClientOptions = &client.Options{ HTTPClient: httpClient, - Mapper: mapper, + Mapper: mapper, } } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go new file mode 100644 index 0000000000..139df552b7 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -0,0 +1,43 @@ +package agent + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" +) + +type GetAgentClientFunc func(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) +type GetAgentMetadataClientFunc func(ctx context.Context, agent *Agent, connCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) + +// Clientset contains the clients exposed to communicate with various agent services. +type ClientFuncSet struct { + getAgentClient GetAgentClientFunc + getAgentMetadataClient GetAgentMetadataClientFunc +} + +func getAgentClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { + conn, err := getGrpcConnection(ctx, agent, connectionCache) + if err != nil { + return nil, err + } + + return service.NewAsyncAgentServiceClient(conn), nil +} + +func getAgentMetadataClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) { + conn, err := getGrpcConnection(ctx, agent, connectionCache) + if err != nil { + return nil, err + } + + return service.NewAgentMetadataServiceClient(conn), nil +} + +func initializeClientFunc() *ClientFuncSet { + return &ClientFuncSet{ + getAgentClient: getAgentClientFunc, + getAgentMetadataClient: getAgentMetadataClientFunc, + } +} diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go new file mode 100644 index 0000000000..5dfbe2f521 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -0,0 +1,14 @@ +package agent + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestInitializeClientFunc(t *testing.T) { + cs := initializeClientFunc() + assert.NotNil(t, cs) + assert.NotNil(t, cs.getAgentClient) + assert.NotNil(t, cs.getAgentMetadataClient) +} diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index 827af0d907..573145d29e 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -171,7 +171,9 @@ func TestEndToEnd(t *testing.T) { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - getClient: mockGetBadAsyncClientFunc, + cs: &ClientFuncSet{ + getAgentClient: mockGetBadAsyncClientFunc, + }, }, }, nil } @@ -311,7 +313,9 @@ func newMockAgentPlugin() webapi.PluginEntry { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - getClient: mockAsyncTaskClientFunc, + cs: &ClientFuncSet{ + getAgentClient: mockAsyncTaskClientFunc, + }, }, }, nil }, @@ -327,7 +331,9 @@ func newMockSyncAgentPlugin() webapi.PluginEntry { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - getClient: mockSyncTaskClientFunc, + cs: &ClientFuncSet{ + getAgentClient: mockSyncTaskClientFunc, + }, }, }, nil }, diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index b20bd62d7a..9badd074eb 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -17,7 +17,6 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" pluginErrors "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" @@ -30,13 +29,10 @@ import ( "github.com/flyteorg/flyte/flytestdlib/promutils" ) -type GetClientFunc func(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) -type GetAgentMetadataClientFunc func(ctx context.Context, agent *Agent, connCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) - type Plugin struct { metricScope promutils.Scope cfg *Config - getClient GetClientFunc + cs *ClientFuncSet connectionCache map[*Agent]*grpc.ClientConn agentRegistry map[string]*Agent // map[taskType] => Agent } @@ -96,7 +92,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR agent := getFinalAgent(taskTemplate.Type, p.cfg, p.agentRegistry) - client, err := p.getClient(ctx, agent, p.connectionCache) + client, err := p.cs.getAgentClient(ctx, agent, p.connectionCache) if err != nil { return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err) } @@ -145,7 +141,7 @@ func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest weba return nil, fmt.Errorf("failed to find agent with error: %v", err) } - client, err := p.getClient(ctx, agent, p.connectionCache) + client, err := p.cs.getAgentClient(ctx, agent, p.connectionCache) if err != nil { return nil, fmt.Errorf("failed to connect to agent with error: %v", err) } @@ -174,7 +170,7 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error agent := getFinalAgent(metadata.TaskType, p.cfg, p.agentRegistry) - client, err := p.getClient(ctx, agent, p.connectionCache) + client, err := p.cs.getAgentClient(ctx, agent, p.connectionCache) if err != nil { return fmt.Errorf("failed to connect to agent with error: %v", err) } @@ -287,24 +283,6 @@ func getGrpcConnection(ctx context.Context, agent *Agent, connectionCache map[*A return conn, nil } -func getClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - conn, err := getGrpcConnection(ctx, agent, connectionCache) - if err != nil { - return nil, err - } - - return service.NewAsyncAgentServiceClient(conn), nil -} - -func getAgentMetadataClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) { - conn, err := getGrpcConnection(ctx, agent, connectionCache) - if err != nil { - return nil, err - } - - return service.NewAgentMetadataServiceClient(conn), nil -} - func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata) admin.TaskExecutionMetadata { taskExecutionID := taskExecutionMetadata.GetTaskExecutionID().GetID() return admin.TaskExecutionMetadata{ @@ -334,7 +312,7 @@ func getFinalContext(ctx context.Context, operation string, agent *Agent) (conte return context.WithTimeout(ctx, timeout) } -func initializeAgentRegistry(cfg *Config, connectionCache map[*Agent]*grpc.ClientConn, getAgentMetadataClientFunc GetAgentMetadataClientFunc) (map[string]*Agent, error) { +func initializeAgentRegistry(cfg *Config, connectionCache map[*Agent]*grpc.ClientConn, cs *ClientFuncSet) (map[string]*Agent, error) { agentRegistry := make(map[string]*Agent) var agentDeployments []*Agent @@ -348,7 +326,7 @@ func initializeAgentRegistry(cfg *Config, connectionCache map[*Agent]*grpc.Clien } agentDeployments = append(agentDeployments, maps.Values(cfg.Agents)...) for _, agentDeployment := range agentDeployments { - client, err := getAgentMetadataClientFunc(context.Background(), agentDeployment, connectionCache) + client, err := cs.getAgentMetadataClient(context.Background(), agentDeployment, connectionCache) if err != nil { return nil, fmt.Errorf("failed to connect to agent [%v] with error: [%v]", agentDeployment, err) } @@ -385,9 +363,10 @@ func initializeAgentRegistry(cfg *Config, connectionCache map[*Agent]*grpc.Clien } func newAgentPlugin() webapi.PluginEntry { + cs := initializeClientFunc() cfg := GetConfig() connectionCache := make(map[*Agent]*grpc.ClientConn) - agentRegistry, err := initializeAgentRegistry(cfg, connectionCache, getAgentMetadataClientFunc) + agentRegistry, err := initializeAgentRegistry(cfg, connectionCache, cs) if err != nil { // We should wait for all agents to be up and running before starting the server panic(err) @@ -403,7 +382,7 @@ func newAgentPlugin() webapi.PluginEntry { return &Plugin{ metricScope: iCtx.MetricsScope(), cfg: cfg, - getClient: getClientFunc, + cs: cs, connectionCache: connectionCache, agentRegistry: agentRegistry, }, nil diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index b9fd7e1b35..a90f265ae8 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -108,13 +108,15 @@ func TestPlugin(t *testing.T) { }) t.Run("test getClientFunc", func(t *testing.T) { - client, err := getClientFunc(context.Background(), &Agent{Endpoint: "localhost:80"}, map[*Agent]*grpc.ClientConn{}) + cs := initializeClientFunc() + client, err := cs.getAgentClient(context.Background(), &Agent{Endpoint: "localhost:80"}, map[*Agent]*grpc.ClientConn{}) assert.NoError(t, err) assert.NotNil(t, client) }) t.Run("test getClientFunc more config", func(t *testing.T) { - client, err := getClientFunc(context.Background(), &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[*Agent]*grpc.ClientConn{}) + cs := initializeClientFunc() + client, err := cs.getAgentClient(context.Background(), &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[*Agent]*grpc.ClientConn{}) assert.NoError(t, err) assert.NotNil(t, client) }) @@ -123,12 +125,13 @@ func TestPlugin(t *testing.T) { connectionCache := make(map[*Agent]*grpc.ClientConn) agent := &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"} - client, err := getClientFunc(context.Background(), agent, connectionCache) + cs := initializeClientFunc() + client, err := cs.getAgentClient(context.Background(), agent, connectionCache) assert.NoError(t, err) assert.NotNil(t, client) assert.NotNil(t, client, connectionCache[agent]) - cachedClient, err := getClientFunc(context.Background(), agent, connectionCache) + cachedClient, err := cs.getAgentClient(context.Background(), agent, connectionCache) assert.NoError(t, err) assert.NotNil(t, cachedClient) assert.Equal(t, client, cachedClient) @@ -238,11 +241,14 @@ func TestInitializeAgentRegistry(t *testing.T) { return mockClient, nil } + cs := initializeClientFunc() + cs.getAgentMetadataClient = getAgentMetadataClientFunc + cfg := defaultConfig cfg.Agents = map[string]*Agent{"custom_agent": {Endpoint: "localhost:80"}} cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"} connectionCache := make(map[*Agent]*grpc.ClientConn) - agentRegistry, err := initializeAgentRegistry(&cfg, connectionCache, getAgentMetadataClientFunc) + agentRegistry, err := initializeAgentRegistry(&cfg, connectionCache, cs) assert.NoError(t, err) // In golang, the order of keys in a map is random. So, we sort the keys before asserting. From d3b92204f5bb83ceb37ea60310da1c2fc0aa8197 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 24 Jan 2024 02:20:27 -0800 Subject: [PATCH 02/12] kevin wip Signed-off-by: Kevin Su --- .../go/tasks/plugins/webapi/agent/client.go | 154 +++++++++++++-- .../tasks/plugins/webapi/agent/client_test.go | 48 ++++- .../plugins/webapi/agent/integration_test.go | 8 +- .../go/tasks/plugins/webapi/agent/plugin.go | 181 ++---------------- .../tasks/plugins/webapi/agent/plugin_test.go | 72 ++----- 5 files changed, 214 insertions(+), 249 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 139df552b7..7277250e60 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -2,42 +2,158 @@ package agent import ( "context" + "crypto/x509" + "fmt" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyte/flytestdlib/config" + "github.com/flyteorg/flyte/flytestdlib/logger" + "golang.org/x/exp/maps" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/status" "google.golang.org/grpc" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" ) -type GetAgentClientFunc func(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) -type GetAgentMetadataClientFunc func(ctx context.Context, agent *Agent, connCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) - -// Clientset contains the clients exposed to communicate with various agent services. -type ClientFuncSet struct { - getAgentClient GetAgentClientFunc - getAgentMetadataClient GetAgentMetadataClientFunc +// ClientSet contains the clients exposed to communicate with various agent services. +type ClientSet struct { + agentClients map[string]service.AsyncAgentServiceClient // map[endpoint] => client + agentMetadataClients map[string]service.AgentMetadataServiceClient // map[endpoint] => client } -func getAgentClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - conn, err := getGrpcConnection(ctx, agent, connectionCache) +func getGrpcConnection(ctx context.Context, agent *Agent) (*grpc.ClientConn, error) { + var opts []grpc.DialOption + + if agent.Insecure { + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } else { + pool, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + + creds := credentials.NewClientTLSFromCert(pool, "") + opts = append(opts, grpc.WithTransportCredentials(creds)) + } + + if len(agent.DefaultServiceConfig) != 0 { + opts = append(opts, grpc.WithDefaultServiceConfig(agent.DefaultServiceConfig)) + } + + var err error + conn, err := grpc.Dial(agent.Endpoint, opts...) if err != nil { return nil, err } + defer func() { + if err != nil { + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", agent, cerr) + } + return + } + go func() { + <-ctx.Done() + if cerr := conn.Close(); cerr != nil { + grpclog.Infof("Failed to close conn to %s: %v", agent, cerr) + } + }() + }() - return service.NewAsyncAgentServiceClient(conn), nil + return conn, nil } -func getAgentMetadataClientFunc(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) { - conn, err := getGrpcConnection(ctx, agent, connectionCache) - if err != nil { - return nil, err +func getFinalTimeout(operation string, agent *Agent) config.Duration { + if t, exists := agent.Timeouts[operation]; exists { + return t + } + + return agent.DefaultTimeout +} + +func getFinalContext(ctx context.Context, operation string, agent *Agent) (context.Context, context.CancelFunc) { + timeout := getFinalTimeout(operation, agent).Duration + if timeout == 0 { + return ctx, func() {} + } + + return context.WithTimeout(ctx, timeout) +} + +func initializeAgentRegistry(cs *ClientSet) (map[string]*Agent, error) { + agentRegistry := make(map[string]*Agent) + cfg := GetConfig() + var agentDeployments []*Agent + + // Ensure that the old configuration is backward compatible + for taskType, agentID := range cfg.AgentForTaskTypes { + agentRegistry[taskType] = cfg.Agents[agentID] + } + + if len(cfg.DefaultAgent.Endpoint) != 0 { + agentDeployments = append(agentDeployments, &cfg.DefaultAgent) } + agentDeployments = append(agentDeployments, maps.Values(cfg.Agents)...) + for _, agentDeployment := range agentDeployments { + client := cs.agentMetadataClients[agentDeployment.Endpoint] + + finalCtx, cancel := getFinalContext(context.Background(), "ListAgents", agentDeployment) + defer cancel() + + res, err := client.ListAgents(finalCtx, &admin.ListAgentsRequest{}) + if err != nil { + grpcStatus, ok := status.FromError(err) + if grpcStatus.Code() == codes.Unimplemented { + // we should not panic here, as we want to continue to support old agent settings + logger.Infof(context.Background(), "list agent method not implemented for agent: [%v]", agentDeployment) + continue + } + + if !ok { + return nil, fmt.Errorf("failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment, err) + } - return service.NewAgentMetadataServiceClient(conn), nil + return nil, fmt.Errorf("failed to list agent: [%v] with error: [%v]", agentDeployment, err) + } + + agents := res.GetAgents() + for _, agent := range agents { + supportedTaskTypes := agent.SupportedTaskTypes + for _, supportedTaskType := range supportedTaskTypes { + agentRegistry[supportedTaskType] = agentDeployment + } + } + } + + return agentRegistry, nil } -func initializeClientFunc() *ClientFuncSet { - return &ClientFuncSet{ - getAgentClient: getAgentClientFunc, - getAgentMetadataClient: getAgentMetadataClientFunc, +func initializeClients(ctx context.Context) (*ClientSet, error) { + agentClients := make(map[string]service.AsyncAgentServiceClient) + agentMetadataClients := make(map[string]service.AgentMetadataServiceClient) + + var agentDeployments []*Agent + cfg := GetConfig() + + if len(cfg.DefaultAgent.Endpoint) != 0 { + agentDeployments = append(agentDeployments, &cfg.DefaultAgent) + } + agentDeployments = append(agentDeployments, maps.Values(cfg.Agents)...) + for _, agentDeployment := range agentDeployments { + conn, err := getGrpcConnection(ctx, agentDeployment) + if err != nil { + return nil, err + } + agentClients[agentDeployment.Endpoint] = service.NewAsyncAgentServiceClient(conn) + agentMetadataClients[agentDeployment.Endpoint] = service.NewAgentMetadataServiceClient(conn) } + + return &ClientSet{ + agentClients: agentClients, + agentMetadataClients: agentMetadataClients, + }, nil } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index 5dfbe2f521..11c235414c 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -1,14 +1,52 @@ package agent import ( - "testing" - + "context" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "testing" ) +func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { + mockMetadataServiceClient := new(agentMocks.AgentMetadataServiceClient) + mockRequest := &admin.ListAgentsRequest{} + mockResponse := &admin.ListAgentsResponse{ + Agents: []*admin.Agent{ + { + Name: "test-agent", + SupportedTaskTypes: []string{"task1", "task2", "task3"}, + }, + }, + } + + mockMetadataServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) + return mockMetadataServiceClient +} + +func getMockServiceClient() *agentMocks.AgentMetadataServiceClient { + mockMetadataServiceClient := new(agentMocks.AgentMetadataServiceClient) + mockRequest := &admin.ListAgentsRequest{} + mockResponse := &admin.ListAgentsResponse{ + Agents: []*admin.Agent{ + { + Name: "test-agent", + SupportedTaskTypes: []string{"task1", "task2", "task3"}, + }, + }, + } + + mockMetadataServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) + return mockMetadataServiceClient +} + func TestInitializeClientFunc(t *testing.T) { - cs := initializeClientFunc() + cfg := defaultConfig + ctx := context.Background() + err := SetConfig(&cfg) + assert.NoError(t, err) + cs, err := initializeClients(ctx) + assert.NoError(t, err) assert.NotNil(t, cs) - assert.NotNil(t, cs.getAgentClient) - assert.NotNil(t, cs.getAgentMetadataClient) } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index 573145d29e..68acd50c9a 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -171,8 +171,8 @@ func TestEndToEnd(t *testing.T) { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - cs: &ClientFuncSet{ - getAgentClient: mockGetBadAsyncClientFunc, + cs: &ClientSet{ + agentClients: mockGetBadAsyncClientFunc, }, }, }, nil @@ -331,8 +331,8 @@ func newMockSyncAgentPlugin() webapi.PluginEntry { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - cs: &ClientFuncSet{ - getAgentClient: mockSyncTaskClientFunc, + cs: &ClientSet{ + agentClients: mockSyncTaskClientFunc, }, }, }, nil diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 9badd074eb..bc2033a70a 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -2,19 +2,10 @@ package agent import ( "context" - "crypto/x509" "encoding/gob" "fmt" "time" - "golang.org/x/exp/maps" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/status" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" pluginErrors "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" @@ -24,17 +15,16 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi" - "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" + "golang.org/x/exp/maps" ) type Plugin struct { - metricScope promutils.Scope - cfg *Config - cs *ClientFuncSet - connectionCache map[*Agent]*grpc.ClientConn - agentRegistry map[string]*Agent // map[taskType] => Agent + metricScope promutils.Scope + cfg *Config + cs *ClientSet + agentRegistry map[string]*Agent // map[taskType] => Agent } type ResourceWrapper struct { @@ -92,11 +82,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR agent := getFinalAgent(taskTemplate.Type, p.cfg, p.agentRegistry) - client, err := p.cs.getAgentClient(ctx, agent, p.connectionCache) - if err != nil { - return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err) - } - + client := p.cs.agentClients[agent.Endpoint] finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent) defer cancel() @@ -135,17 +121,9 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - agent := getFinalAgent(metadata.TaskType, p.cfg, p.agentRegistry) - if err != nil { - return nil, fmt.Errorf("failed to find agent with error: %v", err) - } - - client, err := p.cs.getAgentClient(ctx, agent, p.connectionCache) - if err != nil { - return nil, fmt.Errorf("failed to connect to agent with error: %v", err) - } + client := p.cs.agentClients[agent.Endpoint] finalCtx, cancel := getFinalContext(ctx, "GetTask", agent) defer cancel() @@ -167,18 +145,13 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error return nil } metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - agent := getFinalAgent(metadata.TaskType, p.cfg, p.agentRegistry) - client, err := p.cs.getAgentClient(ctx, agent, p.connectionCache) - if err != nil { - return fmt.Errorf("failed to connect to agent with error: %v", err) - } - + client := p.cs.agentClients[agent.Endpoint] finalCtx, cancel := getFinalContext(ctx, "DeleteTask", agent) defer cancel() - _, err = client.DeleteTask(finalCtx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) + _, err := client.DeleteTask(finalCtx, &admin.DeleteTaskRequest{TaskType: metadata.TaskType, ResourceMeta: metadata.AgentResourceMeta}) return err } @@ -236,53 +209,6 @@ func getFinalAgent(taskType string, cfg *Config, agentRegistry map[string]*Agent return &cfg.DefaultAgent } -func getGrpcConnection(ctx context.Context, agent *Agent, connectionCache map[*Agent]*grpc.ClientConn) (*grpc.ClientConn, error) { - conn, ok := connectionCache[agent] - if ok { - return conn, nil - } - var opts []grpc.DialOption - - if agent.Insecure { - opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) - } else { - pool, err := x509.SystemCertPool() - if err != nil { - return nil, err - } - - creds := credentials.NewClientTLSFromCert(pool, "") - opts = append(opts, grpc.WithTransportCredentials(creds)) - } - - if len(agent.DefaultServiceConfig) != 0 { - opts = append(opts, grpc.WithDefaultServiceConfig(agent.DefaultServiceConfig)) - } - - var err error - conn, err = grpc.Dial(agent.Endpoint, opts...) - if err != nil { - return nil, err - } - connectionCache[agent] = conn - defer func() { - if err != nil { - if cerr := conn.Close(); cerr != nil { - grpclog.Infof("Failed to close conn to %s: %v", agent, cerr) - } - return - } - go func() { - <-ctx.Done() - if cerr := conn.Close(); cerr != nil { - grpclog.Infof("Failed to close conn to %s: %v", agent, cerr) - } - }() - }() - - return conn, nil -} - func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata) admin.TaskExecutionMetadata { taskExecutionID := taskExecutionMetadata.GetTaskExecutionID().GetID() return admin.TaskExecutionMetadata{ @@ -295,83 +221,19 @@ func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata } } -func getFinalTimeout(operation string, agent *Agent) config.Duration { - if t, exists := agent.Timeouts[operation]; exists { - return t - } - - return agent.DefaultTimeout -} - -func getFinalContext(ctx context.Context, operation string, agent *Agent) (context.Context, context.CancelFunc) { - timeout := getFinalTimeout(operation, agent).Duration - if timeout == 0 { - return ctx, func() {} - } - - return context.WithTimeout(ctx, timeout) -} - -func initializeAgentRegistry(cfg *Config, connectionCache map[*Agent]*grpc.ClientConn, cs *ClientFuncSet) (map[string]*Agent, error) { - agentRegistry := make(map[string]*Agent) - var agentDeployments []*Agent - - // Ensure that the old configuration is backward compatible - for taskType, agentID := range cfg.AgentForTaskTypes { - agentRegistry[taskType] = cfg.Agents[agentID] - } - - if len(cfg.DefaultAgent.Endpoint) != 0 { - agentDeployments = append(agentDeployments, &cfg.DefaultAgent) - } - agentDeployments = append(agentDeployments, maps.Values(cfg.Agents)...) - for _, agentDeployment := range agentDeployments { - client, err := cs.getAgentMetadataClient(context.Background(), agentDeployment, connectionCache) - if err != nil { - return nil, fmt.Errorf("failed to connect to agent [%v] with error: [%v]", agentDeployment, err) - } - - finalCtx, cancel := getFinalContext(context.Background(), "ListAgents", agentDeployment) - defer cancel() - - res, err := client.ListAgents(finalCtx, &admin.ListAgentsRequest{}) - if err != nil { - grpcStatus, ok := status.FromError(err) - if grpcStatus.Code() == codes.Unimplemented { - // we should not panic here, as we want to continue to support old agent settings - logger.Infof(context.Background(), "list agent method not implemented for agent: [%v]", agentDeployment) - continue - } - - if !ok { - return nil, fmt.Errorf("failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment, err) - } - - return nil, fmt.Errorf("failed to list agent: [%v] with error: [%v]", agentDeployment, err) - } - - agents := res.GetAgents() - for _, agent := range agents { - supportedTaskTypes := agent.SupportedTaskTypes - for _, supportedTaskType := range supportedTaskTypes { - agentRegistry[supportedTaskType] = agentDeployment - } - } - } - - return agentRegistry, nil -} - func newAgentPlugin() webapi.PluginEntry { - cs := initializeClientFunc() - cfg := GetConfig() - connectionCache := make(map[*Agent]*grpc.ClientConn) - agentRegistry, err := initializeAgentRegistry(cfg, connectionCache, cs) + cs, err := initializeClients(context.Background()) if err != nil { // We should wait for all agents to be up and running before starting the server - panic(err) + panic(fmt.Sprintf("failed to initalize clients with error: %v", err)) } + agentRegistry, err := initializeAgentRegistry(cs) + if err != nil { + panic(fmt.Sprintf("failed to initalize agent registry with error: %v", err)) + } + + cfg := GetConfig() supportedTaskTypes := append(maps.Keys(agentRegistry), cfg.SupportedTaskTypes...) logger.Infof(context.Background(), "Agent supports task types: %v", supportedTaskTypes) @@ -380,11 +242,10 @@ func newAgentPlugin() webapi.PluginEntry { SupportedTaskTypes: supportedTaskTypes, PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { return &Plugin{ - metricScope: iCtx.MetricsScope(), - cfg: cfg, - cs: cs, - connectionCache: connectionCache, - agentRegistry: agentRegistry, + metricScope: iCtx.MetricsScope(), + cfg: cfg, + cs: cs, + agentRegistry: agentRegistry, }, nil }, } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index a90f265ae8..588b20d024 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -2,29 +2,26 @@ package agent import ( "context" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "sort" "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "golang.org/x/exp/maps" - "google.golang.org/grpc" - "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" ioMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" webapiPlugin "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi/mocks" - agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" "github.com/flyteorg/flyte/flyteplugins/tests" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/promutils" "github.com/flyteorg/flyte/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/exp/maps" ) func TestSyncTask(t *testing.T) { @@ -101,42 +98,6 @@ func TestPlugin(t *testing.T) { assert.Equal(t, agent.Endpoint, cfg.DefaultAgent.Endpoint) }) - t.Run("test getAgentMetadataClientFunc", func(t *testing.T) { - client, err := getAgentMetadataClientFunc(context.Background(), &Agent{Endpoint: "localhost:80"}, map[*Agent]*grpc.ClientConn{}) - assert.NoError(t, err) - assert.NotNil(t, client) - }) - - t.Run("test getClientFunc", func(t *testing.T) { - cs := initializeClientFunc() - client, err := cs.getAgentClient(context.Background(), &Agent{Endpoint: "localhost:80"}, map[*Agent]*grpc.ClientConn{}) - assert.NoError(t, err) - assert.NotNil(t, client) - }) - - t.Run("test getClientFunc more config", func(t *testing.T) { - cs := initializeClientFunc() - client, err := cs.getAgentClient(context.Background(), &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"}, map[*Agent]*grpc.ClientConn{}) - assert.NoError(t, err) - assert.NotNil(t, client) - }) - - t.Run("test getClientFunc cache hit", func(t *testing.T) { - connectionCache := make(map[*Agent]*grpc.ClientConn) - agent := &Agent{Endpoint: "localhost:80", Insecure: true, DefaultServiceConfig: "{\"loadBalancingConfig\": [{\"round_robin\":{}}]}"} - - cs := initializeClientFunc() - client, err := cs.getAgentClient(context.Background(), agent, connectionCache) - assert.NoError(t, err) - assert.NotNil(t, client) - assert.NotNil(t, client, connectionCache[agent]) - - cachedClient, err := cs.getAgentClient(context.Background(), agent, connectionCache) - assert.NoError(t, err) - assert.NotNil(t, cachedClient) - assert.Equal(t, client, cachedClient) - }) - t.Run("test getFinalTimeout", func(t *testing.T) { timeout := getFinalTimeout("CreateTask", &Agent{Endpoint: "localhost:8080", Timeouts: map[string]config.Duration{"CreateTask": {Duration: 1 * time.Millisecond}}}) assert.Equal(t, 1*time.Millisecond, timeout.Duration) @@ -225,30 +186,19 @@ func TestPlugin(t *testing.T) { } func TestInitializeAgentRegistry(t *testing.T) { - mockClient := new(agentMocks.AgentMetadataServiceClient) - mockRequest := &admin.ListAgentsRequest{} - mockResponse := &admin.ListAgentsResponse{ - Agents: []*admin.Agent{ - { - Name: "test-agent", - SupportedTaskTypes: []string{"task1", "task2", "task3"}, - }, - }, - } + agentClients := make(map[string]service.AsyncAgentServiceClient) + agentMetadataClients := make(map[string]service.AgentMetadataServiceClient) + agentMetadataClients["localhost:80"] = getMockMetadataServiceClient() - mockClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) - getAgentMetadataClientFunc := func(ctx context.Context, agent *Agent, connCache map[*Agent]*grpc.ClientConn) (service.AgentMetadataServiceClient, error) { - return mockClient, nil + cs := &ClientSet{ + agentClients: agentClients, + agentMetadataClients: agentMetadataClients, } - cs := initializeClientFunc() - cs.getAgentMetadataClient = getAgentMetadataClientFunc - cfg := defaultConfig cfg.Agents = map[string]*Agent{"custom_agent": {Endpoint: "localhost:80"}} cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"} - connectionCache := make(map[*Agent]*grpc.ClientConn) - agentRegistry, err := initializeAgentRegistry(&cfg, connectionCache, cs) + agentRegistry, err := initializeAgentRegistry(cs) assert.NoError(t, err) // In golang, the order of keys in a map is random. So, we sort the keys before asserting. From 48d78d98d2ae128323c219be5a93ffc8cc6ff972 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 24 Jan 2024 23:05:56 +0800 Subject: [PATCH 03/12] add mockery AsyncAgentClient Signed-off-by: Future-Outlier --- .../go/tasks/plugins/webapi/agent/client.go | 2 + .../tasks/plugins/webapi/agent/client_test.go | 16 +- .../plugins/webapi/agent/integration_test.go | 22 +-- .../agent/mocks/AsyncAgentServiceClient.go | 162 ++++++++++++++++++ .../go/tasks/plugins/webapi/agent/plugin.go | 4 +- .../tasks/plugins/webapi/agent/plugin_test.go | 12 +- 6 files changed, 194 insertions(+), 24 deletions(-) create mode 100644 flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 7277250e60..9cfcb2f4db 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -4,6 +4,7 @@ import ( "context" "crypto/x509" "fmt" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/logger" @@ -88,6 +89,7 @@ func initializeAgentRegistry(cs *ClientSet) (map[string]*Agent, error) { agentRegistry := make(map[string]*Agent) cfg := GetConfig() var agentDeployments []*Agent + fmt.Printf("@@@ cfg.AgentForTaskTypes: [%v]\n", cfg.AgentForTaskTypes) // Ensure that the old configuration is backward compatible for taskType, agentID := range cfg.AgentForTaskTypes { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index 11c235414c..3d992c6a85 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -2,11 +2,12 @@ package agent import ( "context" + "testing" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "testing" ) func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { @@ -25,8 +26,9 @@ func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { return mockMetadataServiceClient } -func getMockServiceClient() *agentMocks.AgentMetadataServiceClient { - mockMetadataServiceClient := new(agentMocks.AgentMetadataServiceClient) +// TODO, USE CREATE, GET DELETE FUNCTION TO MOCK THE OUTPUT +func getMockServiceClient() *agentMocks.AsyncAgentServiceClient { + mockServiceClient := new(agentMocks.AsyncAgentServiceClient) mockRequest := &admin.ListAgentsRequest{} mockResponse := &admin.ListAgentsResponse{ Agents: []*admin.Agent{ @@ -37,8 +39,12 @@ func getMockServiceClient() *agentMocks.AgentMetadataServiceClient { }, } - mockMetadataServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) - return mockMetadataServiceClient + mockServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) + return mockServiceClient +} + +func mockGetBadAsyncClientFunc() *agentMocks.AsyncAgentServiceClient { + return nil } func TestInitializeClientFunc(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index 68acd50c9a..29c2db78d2 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -108,9 +108,9 @@ func mockSyncTaskClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.Clie return &MockSyncTask{}, nil } -func mockGetBadAsyncClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - return nil, fmt.Errorf("error") -} +// func mockGetBadAsyncClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { +// return nil, fmt.Errorf("error") +// } func TestEndToEnd(t *testing.T) { iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { @@ -172,7 +172,9 @@ func TestEndToEnd(t *testing.T) { metricScope: iCtx.MetricsScope(), cfg: GetConfig(), cs: &ClientSet{ - agentClients: mockGetBadAsyncClientFunc, + agentClients: map[string]service.AsyncAgentServiceClient{ + "localhost:80": mockGetBadAsyncClientFunc(), + }, }, }, }, nil @@ -313,9 +315,9 @@ func newMockAgentPlugin() webapi.PluginEntry { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - cs: &ClientFuncSet{ - getAgentClient: mockAsyncTaskClientFunc, - }, + // cs: &ClientSet{ + // getAgentClient: mockAsyncTaskClientFunc, + // }, }, }, nil }, @@ -331,9 +333,9 @@ func newMockSyncAgentPlugin() webapi.PluginEntry { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - cs: &ClientSet{ - agentClients: mockSyncTaskClientFunc, - }, + // cs: &ClientSet{ + // agentClients: mockSyncTaskClientFunc, + // }, }, }, nil }, diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go b/flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go new file mode 100644 index 0000000000..4a2b2c25f3 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go @@ -0,0 +1,162 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + admin "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + + grpc "google.golang.org/grpc" + + mock "github.com/stretchr/testify/mock" +) + +// AsyncAgentServiceClient is an autogenerated mock type for the AsyncAgentServiceClient type +type AsyncAgentServiceClient struct { + mock.Mock +} + +type AsyncAgentServiceClient_CreateTask struct { + *mock.Call +} + +func (_m AsyncAgentServiceClient_CreateTask) Return(_a0 *admin.CreateTaskResponse, _a1 error) *AsyncAgentServiceClient_CreateTask { + return &AsyncAgentServiceClient_CreateTask{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *AsyncAgentServiceClient) OnCreateTask(ctx context.Context, in *admin.CreateTaskRequest, opts ...grpc.CallOption) *AsyncAgentServiceClient_CreateTask { + c_call := _m.On("CreateTask", ctx, in, opts) + return &AsyncAgentServiceClient_CreateTask{Call: c_call} +} + +func (_m *AsyncAgentServiceClient) OnCreateTaskMatch(matchers ...interface{}) *AsyncAgentServiceClient_CreateTask { + c_call := _m.On("CreateTask", matchers...) + return &AsyncAgentServiceClient_CreateTask{Call: c_call} +} + +// CreateTask provides a mock function with given fields: ctx, in, opts +func (_m *AsyncAgentServiceClient) CreateTask(ctx context.Context, in *admin.CreateTaskRequest, opts ...grpc.CallOption) (*admin.CreateTaskResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.CreateTaskResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.CreateTaskRequest, ...grpc.CallOption) *admin.CreateTaskResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.CreateTaskResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.CreateTaskRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type AsyncAgentServiceClient_DeleteTask struct { + *mock.Call +} + +func (_m AsyncAgentServiceClient_DeleteTask) Return(_a0 *admin.DeleteTaskResponse, _a1 error) *AsyncAgentServiceClient_DeleteTask { + return &AsyncAgentServiceClient_DeleteTask{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *AsyncAgentServiceClient) OnDeleteTask(ctx context.Context, in *admin.DeleteTaskRequest, opts ...grpc.CallOption) *AsyncAgentServiceClient_DeleteTask { + c_call := _m.On("DeleteTask", ctx, in, opts) + return &AsyncAgentServiceClient_DeleteTask{Call: c_call} +} + +func (_m *AsyncAgentServiceClient) OnDeleteTaskMatch(matchers ...interface{}) *AsyncAgentServiceClient_DeleteTask { + c_call := _m.On("DeleteTask", matchers...) + return &AsyncAgentServiceClient_DeleteTask{Call: c_call} +} + +// DeleteTask provides a mock function with given fields: ctx, in, opts +func (_m *AsyncAgentServiceClient) DeleteTask(ctx context.Context, in *admin.DeleteTaskRequest, opts ...grpc.CallOption) (*admin.DeleteTaskResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.DeleteTaskResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.DeleteTaskRequest, ...grpc.CallOption) *admin.DeleteTaskResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.DeleteTaskResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.DeleteTaskRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type AsyncAgentServiceClient_GetTask struct { + *mock.Call +} + +func (_m AsyncAgentServiceClient_GetTask) Return(_a0 *admin.GetTaskResponse, _a1 error) *AsyncAgentServiceClient_GetTask { + return &AsyncAgentServiceClient_GetTask{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *AsyncAgentServiceClient) OnGetTask(ctx context.Context, in *admin.GetTaskRequest, opts ...grpc.CallOption) *AsyncAgentServiceClient_GetTask { + c_call := _m.On("GetTask", ctx, in, opts) + return &AsyncAgentServiceClient_GetTask{Call: c_call} +} + +func (_m *AsyncAgentServiceClient) OnGetTaskMatch(matchers ...interface{}) *AsyncAgentServiceClient_GetTask { + c_call := _m.On("GetTask", matchers...) + return &AsyncAgentServiceClient_GetTask{Call: c_call} +} + +// GetTask provides a mock function with given fields: ctx, in, opts +func (_m *AsyncAgentServiceClient) GetTask(ctx context.Context, in *admin.GetTaskRequest, opts ...grpc.CallOption) (*admin.GetTaskResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.GetTaskResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.GetTaskRequest, ...grpc.CallOption) *admin.GetTaskResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.GetTaskResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.GetTaskRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index bc2033a70a..b99ee357af 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -225,12 +225,12 @@ func newAgentPlugin() webapi.PluginEntry { cs, err := initializeClients(context.Background()) if err != nil { // We should wait for all agents to be up and running before starting the server - panic(fmt.Sprintf("failed to initalize clients with error: %v", err)) + panic(fmt.Sprintf("failed to initialize clients with error: %v", err)) } agentRegistry, err := initializeAgentRegistry(cs) if err != nil { - panic(fmt.Sprintf("failed to initalize agent registry with error: %v", err)) + panic(fmt.Sprintf("failed to initialize agent registry with error: %v", err)) } cfg := GetConfig() diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 588b20d024..14b0c10c89 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -2,11 +2,12 @@ package agent import ( "context" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" - "sort" "testing" "time" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" + "golang.org/x/exp/maps" + "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" @@ -21,7 +22,6 @@ import ( "github.com/flyteorg/flyte/flytestdlib/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "golang.org/x/exp/maps" ) func TestSyncTask(t *testing.T) { @@ -188,6 +188,7 @@ func TestPlugin(t *testing.T) { func TestInitializeAgentRegistry(t *testing.T) { agentClients := make(map[string]service.AsyncAgentServiceClient) agentMetadataClients := make(map[string]service.AgentMetadataServiceClient) + agentClients["localhost:80"] = getMockServiceClient() agentMetadataClients["localhost:80"] = getMockMetadataServiceClient() cs := &ClientSet{ @@ -201,9 +202,6 @@ func TestInitializeAgentRegistry(t *testing.T) { agentRegistry, err := initializeAgentRegistry(cs) assert.NoError(t, err) - // In golang, the order of keys in a map is random. So, we sort the keys before asserting. agentRegistryKeys := maps.Keys(agentRegistry) - sort.Strings(agentRegistryKeys) - - assert.Equal(t, agentRegistryKeys, []string{"task1", "task2", "task3"}) + assert.Equal(t, agentRegistryKeys, []string{}) } From 0cbfd41488bdb14de0f45e6cf37a5ffe98215d6b Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 25 Jan 2024 06:18:45 +0800 Subject: [PATCH 04/12] improve error message Signed-off-by: Future-Outlier --- flyteplugins/go/tasks/plugins/webapi/agent/plugin.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index b99ee357af..ba4efd30e6 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -83,6 +83,10 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR agent := getFinalAgent(taskTemplate.Type, p.cfg, p.agentRegistry) client := p.cs.agentClients[agent.Endpoint] + if client == nil { + return nil, nil, fmt.Errorf("agent:[%v] is not connected, please check if the agent is up and running", agent) + } + finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent) defer cancel() From 376df9c1637bffefc9b1f32597360c4115d7a59c Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 25 Jan 2024 06:19:25 +0800 Subject: [PATCH 05/12] improve error message Signed-off-by: Future-Outlier --- flyteplugins/go/tasks/plugins/webapi/agent/plugin.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index ba4efd30e6..9ef59c9832 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -84,7 +84,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR client := p.cs.agentClients[agent.Endpoint] if client == nil { - return nil, nil, fmt.Errorf("agent:[%v] is not connected, please check if the agent is up and running", agent) + return nil, nil, fmt.Errorf("default agent:[%v] is not connected, please check if the default agent is up and running", agent) } finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent) From f6643ac0f43104c8e9d6e28da26b8772b17433c4 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 25 Jan 2024 06:20:22 +0800 Subject: [PATCH 06/12] improve error message Signed-off-by: Future-Outlier --- flyteplugins/go/tasks/plugins/webapi/agent/plugin.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 9ef59c9832..fdba7001db 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -84,7 +84,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR client := p.cs.agentClients[agent.Endpoint] if client == nil { - return nil, nil, fmt.Errorf("default agent:[%v] is not connected, please check if the default agent is up and running", agent) + return nil, nil, fmt.Errorf("default agent:[%v] is not connected, please check if it is up and running", agent) } finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent) From 1b1936e0bac54583212fffc161422aacde1a0f0a Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 25 Jan 2024 06:44:08 +0800 Subject: [PATCH 07/12] need to use mockery AsyncAgentClient FIrst Signed-off-by: Future-Outlier --- .../go/tasks/plugins/webapi/agent/client.go | 10 +++---- .../tasks/plugins/webapi/agent/client_test.go | 5 ++-- .../plugins/webapi/agent/integration_test.go | 29 +++++++++++-------- .../go/tasks/plugins/webapi/agent/plugin.go | 3 +- .../tasks/plugins/webapi/agent/plugin_test.go | 6 ++-- 5 files changed, 29 insertions(+), 24 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 9cfcb2f4db..b118f64596 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -5,19 +5,18 @@ import ( "crypto/x509" "fmt" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" - "github.com/flyteorg/flyte/flytestdlib/config" - "github.com/flyteorg/flyte/flytestdlib/logger" "golang.org/x/exp/maps" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/status" - "google.golang.org/grpc" - + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flyte/flytestdlib/config" + "github.com/flyteorg/flyte/flytestdlib/logger" ) // ClientSet contains the clients exposed to communicate with various agent services. @@ -89,7 +88,6 @@ func initializeAgentRegistry(cs *ClientSet) (map[string]*Agent, error) { agentRegistry := make(map[string]*Agent) cfg := GetConfig() var agentDeployments []*Agent - fmt.Printf("@@@ cfg.AgentForTaskTypes: [%v]\n", cfg.AgentForTaskTypes) // Ensure that the old configuration is backward compatible for taskType, agentID := range cfg.AgentForTaskTypes { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index 3d992c6a85..fa43901166 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -4,10 +4,11 @@ import ( "context" "testing" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" - agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" + agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" ) func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index 29c2db78d2..e5d3d90cc0 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -100,13 +100,13 @@ func (m *MockSyncTask) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, return &admin.DeleteTaskResponse{}, nil } -func mockAsyncTaskClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - return &MockAsyncTask{}, nil -} +// func mockAsyncTaskClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { +// return &MockAsyncTask{}, nil +// } -func mockSyncTaskClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { - return &MockSyncTask{}, nil -} +// func mockSyncTaskClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { +// return &MockSyncTask{}, nil +// } // func mockGetBadAsyncClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { // return nil, fmt.Errorf("error") @@ -307,6 +307,7 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext { } func newMockAgentPlugin() webapi.PluginEntry { + return webapi.PluginEntry{ ID: "agent-service", SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark_job", "api_task"}, @@ -315,9 +316,11 @@ func newMockAgentPlugin() webapi.PluginEntry { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - // cs: &ClientSet{ - // getAgentClient: mockAsyncTaskClientFunc, - // }, + cs: &ClientSet{ + agentClients: map[string]service.AsyncAgentServiceClient{ + "": &MockAsyncTask{}, + }, + }, }, }, nil }, @@ -333,9 +336,11 @@ func newMockSyncAgentPlugin() webapi.PluginEntry { Plugin{ metricScope: iCtx.MetricsScope(), cfg: GetConfig(), - // cs: &ClientSet{ - // agentClients: mockSyncTaskClientFunc, - // }, + cs: &ClientSet{ + agentClients: map[string]service.AsyncAgentServiceClient{ + "": &MockSyncTask{}, + }, + }, }, }, nil }, diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index fdba7001db..5b2f0d4f51 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -6,6 +6,8 @@ import ( "fmt" "time" + "golang.org/x/exp/maps" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" pluginErrors "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" @@ -17,7 +19,6 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi" "github.com/flyteorg/flyte/flytestdlib/logger" "github.com/flyteorg/flyte/flytestdlib/promutils" - "golang.org/x/exp/maps" ) type Plugin struct { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 14b0c10c89..7e21e5a92a 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -5,12 +5,14 @@ import ( "testing" "time" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "golang.org/x/exp/maps" "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" @@ -20,8 +22,6 @@ import ( "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/promutils" "github.com/flyteorg/flyte/flytestdlib/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" ) func TestSyncTask(t *testing.T) { From cd533da00cdc47afc27b039db4c3172e1ad4fd90 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Thu, 25 Jan 2024 09:53:59 +0800 Subject: [PATCH 08/12] set config TestInitializeAgentRegistry Signed-off-by: Future-Outlier --- .../tasks/plugins/webapi/agent/client_test.go | 17 ----------------- .../plugins/webapi/agent/integration_test.go | 12 ------------ .../tasks/plugins/webapi/agent/plugin_test.go | 10 ++++++++-- 3 files changed, 8 insertions(+), 31 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index fa43901166..a75afb157d 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -27,23 +27,6 @@ func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { return mockMetadataServiceClient } -// TODO, USE CREATE, GET DELETE FUNCTION TO MOCK THE OUTPUT -func getMockServiceClient() *agentMocks.AsyncAgentServiceClient { - mockServiceClient := new(agentMocks.AsyncAgentServiceClient) - mockRequest := &admin.ListAgentsRequest{} - mockResponse := &admin.ListAgentsResponse{ - Agents: []*admin.Agent{ - { - Name: "test-agent", - SupportedTaskTypes: []string{"task1", "task2", "task3"}, - }, - }, - } - - mockServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) - return mockServiceClient -} - func mockGetBadAsyncClientFunc() *agentMocks.AsyncAgentServiceClient { return nil } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index e5d3d90cc0..5943b36c78 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -100,18 +100,6 @@ func (m *MockSyncTask) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, return &admin.DeleteTaskResponse{}, nil } -// func mockAsyncTaskClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { -// return &MockAsyncTask{}, nil -// } - -// func mockSyncTaskClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { -// return &MockSyncTask{}, nil -// } - -// func mockGetBadAsyncClientFunc(_ context.Context, _ *Agent, _ map[*Agent]*grpc.ClientConn) (service.AsyncAgentServiceClient, error) { -// return nil, fmt.Errorf("error") -// } - func TestEndToEnd(t *testing.T) { iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { return nil diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 7e21e5a92a..2ccd56a24b 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -2,6 +2,7 @@ package agent import ( "context" + "sort" "testing" "time" @@ -188,7 +189,7 @@ func TestPlugin(t *testing.T) { func TestInitializeAgentRegistry(t *testing.T) { agentClients := make(map[string]service.AsyncAgentServiceClient) agentMetadataClients := make(map[string]service.AgentMetadataServiceClient) - agentClients["localhost:80"] = getMockServiceClient() + agentClients["localhost:80"] = &MockAsyncTask{} agentMetadataClients["localhost:80"] = getMockMetadataServiceClient() cs := &ClientSet{ @@ -199,9 +200,14 @@ func TestInitializeAgentRegistry(t *testing.T) { cfg := defaultConfig cfg.Agents = map[string]*Agent{"custom_agent": {Endpoint: "localhost:80"}} cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"} + err := SetConfig(&cfg) + assert.NoError(t, err) agentRegistry, err := initializeAgentRegistry(cs) assert.NoError(t, err) + // In golang, the order of keys in a map is random. So, we sort the keys before asserting. agentRegistryKeys := maps.Keys(agentRegistry) - assert.Equal(t, agentRegistryKeys, []string{}) + sort.Strings(agentRegistryKeys) + + assert.Equal(t, agentRegistryKeys, []string{"task1", "task2", "task3"}) } From 80cf33ffabb4284068d417a5376dccd1f43f4ff7 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Fri, 26 Jan 2024 10:57:32 +0800 Subject: [PATCH 09/12] push change Signed-off-by: Future-Outlier --- .../tasks/plugins/webapi/agent/client_test.go | 18 -------- .../plugins/webapi/agent/integration_test.go | 46 +++++++++++++++++-- .../go/tasks/plugins/webapi/agent/plugin.go | 2 +- .../tasks/plugins/webapi/agent/plugin_test.go | 24 ++++++++-- 4 files changed, 63 insertions(+), 27 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index a75afb157d..235c3fe7e3 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -5,28 +5,10 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" ) -func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { - mockMetadataServiceClient := new(agentMocks.AgentMetadataServiceClient) - mockRequest := &admin.ListAgentsRequest{} - mockResponse := &admin.ListAgentsResponse{ - Agents: []*admin.Agent{ - { - Name: "test-agent", - SupportedTaskTypes: []string{"task1", "task2", "task3"}, - }, - }, - } - - mockMetadataServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) - return mockMetadataServiceClient -} - func mockGetBadAsyncClientFunc() *agentMocks.AsyncAgentServiceClient { return nil } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index 5943b36c78..d476fa2119 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -25,6 +25,7 @@ import ( pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" ioMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi" + agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" "github.com/flyteorg/flyte/flyteplugins/tests" "github.com/flyteorg/flyte/flytestdlib/contextutils" "github.com/flyteorg/flyte/flytestdlib/promutils" @@ -160,9 +161,8 @@ func TestEndToEnd(t *testing.T) { metricScope: iCtx.MetricsScope(), cfg: GetConfig(), cs: &ClientSet{ - agentClients: map[string]service.AsyncAgentServiceClient{ - "localhost:80": mockGetBadAsyncClientFunc(), - }, + agentClients: map[string]service.AsyncAgentServiceClient{}, + agentMetadataClients: map[string]service.AgentMetadataServiceClient{}, }, }, }, nil @@ -296,6 +296,44 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext { func newMockAgentPlugin() webapi.PluginEntry { + agentClient := new(agentMocks.AsyncAgentServiceClient) + // mockCreateRequest := &admin.CreateTaskRequest{} + // agentClient.On("CreateTask", mock.Anything, mockCreateRequest).Return( + // &admin.CreateTaskResponse{ + // Res: &admin.CreateTaskResponse_ResourceMeta{ + // ResourceMeta: []byte{1, 2, 3, 4}, + // }}, nil) + + agentClient.On("CreateTask", mock.Anything, mock.MatchedBy(func(req *admin.CreateTaskRequest) bool { + // Your custom logic to decide if the condition is met + expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"} + return slices.Equal(req.Template.GetContainer().Args, expectedArgs) + })).Run(func(args mock.Arguments) { + req := args.Get(1).(*admin.CreateTaskRequest) + // Extract the mock.Call object + call := args.Get(0).(mock.Call) + + if slices.Equal(req.Template.GetContainer().Args, []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"}) { + // If condition is met, return a specific response + call.Return(&admin.CreateTaskResponse{ + Res: &admin.CreateTaskResponse_ResourceMeta{ + ResourceMeta: []byte{1, 2, 3, 4}, + }, + }, nil) + } else { + // Else, return a different response or error + call.Return(nil, fmt.Errorf("unexpected arguments")) + } + }).Maybe() + + mockGetRequest := &admin.GetTaskRequest{} + agentClient.On("GetTask", mock.Anything, mockGetRequest).Return( + &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil) + + mockDeleteRequest := &admin.DeleteTaskRequest{} + agentClient.On("DeleteTask", mock.Anything, mockDeleteRequest).Return( + &admin.DeleteTaskResponse{}, nil) + return webapi.PluginEntry{ ID: "agent-service", SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark_job", "api_task"}, @@ -306,7 +344,7 @@ func newMockAgentPlugin() webapi.PluginEntry { cfg: GetConfig(), cs: &ClientSet{ agentClients: map[string]service.AsyncAgentServiceClient{ - "": &MockAsyncTask{}, + "": agentClient, }, }, }, diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 5b2f0d4f51..8f5ce243af 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -85,7 +85,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR client := p.cs.agentClients[agent.Endpoint] if client == nil { - return nil, nil, fmt.Errorf("default agent:[%v] is not connected, please check if it is up and running", agent) + return nil, nil, fmt.Errorf("default agent is not connected, please check if endpoint:[%v] is up and running", agent.Endpoint) } finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 2ccd56a24b..d17adf5180 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -6,10 +6,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "golang.org/x/exp/maps" - "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" @@ -19,10 +15,14 @@ import ( pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" ioMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" webapiPlugin "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi/mocks" + agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" "github.com/flyteorg/flyte/flyteplugins/tests" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/promutils" "github.com/flyteorg/flyte/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/exp/maps" ) func TestSyncTask(t *testing.T) { @@ -186,6 +186,22 @@ func TestPlugin(t *testing.T) { }) } +func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { + mockMetadataServiceClient := new(agentMocks.AgentMetadataServiceClient) + mockRequest := &admin.ListAgentsRequest{} + mockResponse := &admin.ListAgentsResponse{ + Agents: []*admin.Agent{ + { + Name: "test-agent", + SupportedTaskTypes: []string{"task1", "task2", "task3"}, + }, + }, + } + + mockMetadataServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) + return mockMetadataServiceClient +} + func TestInitializeAgentRegistry(t *testing.T) { agentClients := make(map[string]service.AsyncAgentServiceClient) agentMetadataClients := make(map[string]service.AgentMetadataServiceClient) From 3ac82cac1c656d1745587910bee82da1fcd8360e Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 31 Jan 2024 01:27:06 -0800 Subject: [PATCH 10/12] make generate Signed-off-by: Kevin Su --- .../agent/mocks/AsyncAgentServiceClient.go | 96 +++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go b/flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go index 4a2b2c25f3..f11ef1adfe 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/mocks/AsyncAgentServiceClient.go @@ -160,3 +160,99 @@ func (_m *AsyncAgentServiceClient) GetTask(ctx context.Context, in *admin.GetTas return r0, r1 } + +type AsyncAgentServiceClient_GetTaskLogs struct { + *mock.Call +} + +func (_m AsyncAgentServiceClient_GetTaskLogs) Return(_a0 *admin.GetTaskLogsResponse, _a1 error) *AsyncAgentServiceClient_GetTaskLogs { + return &AsyncAgentServiceClient_GetTaskLogs{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *AsyncAgentServiceClient) OnGetTaskLogs(ctx context.Context, in *admin.GetTaskLogsRequest, opts ...grpc.CallOption) *AsyncAgentServiceClient_GetTaskLogs { + c_call := _m.On("GetTaskLogs", ctx, in, opts) + return &AsyncAgentServiceClient_GetTaskLogs{Call: c_call} +} + +func (_m *AsyncAgentServiceClient) OnGetTaskLogsMatch(matchers ...interface{}) *AsyncAgentServiceClient_GetTaskLogs { + c_call := _m.On("GetTaskLogs", matchers...) + return &AsyncAgentServiceClient_GetTaskLogs{Call: c_call} +} + +// GetTaskLogs provides a mock function with given fields: ctx, in, opts +func (_m *AsyncAgentServiceClient) GetTaskLogs(ctx context.Context, in *admin.GetTaskLogsRequest, opts ...grpc.CallOption) (*admin.GetTaskLogsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.GetTaskLogsResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.GetTaskLogsRequest, ...grpc.CallOption) *admin.GetTaskLogsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.GetTaskLogsResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.GetTaskLogsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type AsyncAgentServiceClient_GetTaskMetrics struct { + *mock.Call +} + +func (_m AsyncAgentServiceClient_GetTaskMetrics) Return(_a0 *admin.GetTaskMetricsResponse, _a1 error) *AsyncAgentServiceClient_GetTaskMetrics { + return &AsyncAgentServiceClient_GetTaskMetrics{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *AsyncAgentServiceClient) OnGetTaskMetrics(ctx context.Context, in *admin.GetTaskMetricsRequest, opts ...grpc.CallOption) *AsyncAgentServiceClient_GetTaskMetrics { + c_call := _m.On("GetTaskMetrics", ctx, in, opts) + return &AsyncAgentServiceClient_GetTaskMetrics{Call: c_call} +} + +func (_m *AsyncAgentServiceClient) OnGetTaskMetricsMatch(matchers ...interface{}) *AsyncAgentServiceClient_GetTaskMetrics { + c_call := _m.On("GetTaskMetrics", matchers...) + return &AsyncAgentServiceClient_GetTaskMetrics{Call: c_call} +} + +// GetTaskMetrics provides a mock function with given fields: ctx, in, opts +func (_m *AsyncAgentServiceClient) GetTaskMetrics(ctx context.Context, in *admin.GetTaskMetricsRequest, opts ...grpc.CallOption) (*admin.GetTaskMetricsResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *admin.GetTaskMetricsResponse + if rf, ok := ret.Get(0).(func(context.Context, *admin.GetTaskMetricsRequest, ...grpc.CallOption) *admin.GetTaskMetricsResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*admin.GetTaskMetricsResponse) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *admin.GetTaskMetricsRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} From 3ea3f13d57a9b2fb230487733e86728a9ccfe279 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 31 Jan 2024 03:36:37 -0800 Subject: [PATCH 11/12] update tests Signed-off-by: Kevin Su --- .../tasks/plugins/webapi/agent/client_test.go | 8 +- .../plugins/webapi/agent/integration_test.go | 186 +++--------------- .../tasks/plugins/webapi/agent/plugin_test.go | 41 +--- 3 files changed, 35 insertions(+), 200 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index 235c3fe7e3..d68811d037 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -5,15 +5,9 @@ import ( "testing" "github.com/stretchr/testify/assert" - - agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" ) -func mockGetBadAsyncClientFunc() *agentMocks.AsyncAgentServiceClient { - return nil -} - -func TestInitializeClientFunc(t *testing.T) { +func TestInitializeClients(t *testing.T) { cfg := defaultConfig ctx := context.Background() err := SetConfig(&cfg) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index c1c7687737..998b0fcc14 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -10,7 +10,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "google.golang.org/grpc" "k8s.io/apimachinery/pkg/util/rand" "k8s.io/utils/strings/slices" @@ -34,89 +33,6 @@ import ( "github.com/flyteorg/flyte/flytestdlib/utils" ) -type MockPlugin struct { - Plugin -} - -type MockAsyncTask struct { -} - -func (m *MockAsyncTask) GetTaskMetrics(ctx context.Context, in *admin.GetTaskMetricsRequest, opts ...grpc.CallOption) (*admin.GetTaskMetricsResponse, error) { - panic("not implemented") -} - -func (m *MockAsyncTask) GetTaskLogs(ctx context.Context, in *admin.GetTaskLogsRequest, opts ...grpc.CallOption) (*admin.GetTaskLogsResponse, error) { - panic("not implemented") -} - -type MockSyncTask struct { -} - -func (m *MockSyncTask) GetTaskMetrics(ctx context.Context, in *admin.GetTaskMetricsRequest, opts ...grpc.CallOption) (*admin.GetTaskMetricsResponse, error) { - panic("not implemented") -} - -func (m *MockSyncTask) GetTaskLogs(ctx context.Context, in *admin.GetTaskLogsRequest, opts ...grpc.CallOption) (*admin.GetTaskLogsResponse, error) { - panic("not implemented") -} - -func (m *MockAsyncTask) CreateTask(_ context.Context, createTaskRequest *admin.CreateTaskRequest, _ ...grpc.CallOption) (*admin.CreateTaskResponse, error) { - expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"} - if slices.Equal(createTaskRequest.Template.GetContainer().Args, expectedArgs) { - return nil, fmt.Errorf("args not as expected") - } - return &admin.CreateTaskResponse{ - Res: &admin.CreateTaskResponse_ResourceMeta{ - ResourceMeta: []byte{1, 2, 3, 4}, - }}, nil -} - -func (m *MockAsyncTask) GetTask(_ context.Context, req *admin.GetTaskRequest, _ ...grpc.CallOption) (*admin.GetTaskResponse, error) { - if req.GetTaskType() == "bigquery_query_job_task" { - return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{ - Literals: map[string]*flyteIdlCore.Literal{ - "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), - }, - }}}, nil - } - return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil -} - -func (m *MockAsyncTask) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _ ...grpc.CallOption) (*admin.DeleteTaskResponse, error) { - return &admin.DeleteTaskResponse{}, nil -} - -func (m *MockSyncTask) CreateTask(_ context.Context, createTaskRequest *admin.CreateTaskRequest, _ ...grpc.CallOption) (*admin.CreateTaskResponse, error) { - return &admin.CreateTaskResponse{ - Res: &admin.CreateTaskResponse_Resource{ - Resource: &admin.Resource{ - State: admin.State_SUCCEEDED, - Outputs: &flyteIdlCore.LiteralMap{ - Literals: map[string]*flyteIdlCore.Literal{}, - }, - Message: "Sync task finished", - LogLinks: []*flyteIdlCore.TaskLog{{Uri: "http://localhost:3000/log", Name: "Log Link"}}, - }, - }, - }, nil - -} - -func (m *MockSyncTask) GetTask(_ context.Context, req *admin.GetTaskRequest, _ ...grpc.CallOption) (*admin.GetTaskResponse, error) { - if req.GetTaskType() == "fake_task" { - return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED, Outputs: &flyteIdlCore.LiteralMap{ - Literals: map[string]*flyteIdlCore.Literal{ - "arr": coreutils.MustMakeLiteral([]interface{}{[]interface{}{"a", "b"}, []interface{}{1, 2}}), - }, - }}}, nil - } - return &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil -} - -func (m *MockSyncTask) DeleteTask(_ context.Context, _ *admin.DeleteTaskRequest, _ ...grpc.CallOption) (*admin.DeleteTaskResponse, error) { - return &admin.DeleteTaskResponse{}, nil -} - func TestEndToEnd(t *testing.T) { iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { return nil @@ -126,6 +42,7 @@ func TestEndToEnd(t *testing.T) { cfg.WebAPI.ResourceQuotas = map[core.ResourceNamespace]int{} cfg.WebAPI.Caching.Workers = 1 cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + cfg.DefaultAgent.Endpoint = "localhost:8000" err := SetConfig(&cfg) assert.NoError(t, err) @@ -147,10 +64,10 @@ func TestEndToEnd(t *testing.T) { inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) template := flyteIdlCore.TaskTemplate{ - Type: "bigquery_query_job_task", + Type: "databricks", Custom: st, Target: &flyteIdlCore.TaskTemplate_Container{ - Container: &flyteIdlCore.Container{Args: []string{"pyflyte-fast-execute", "--output-prefix", "{{.outputPrefix}}"}}, + Container: &flyteIdlCore.Container{Args: []string{"pyflyte-fast-execute", "--output-prefix", "/tmp/123"}}, }, } basePrefix := storage.DataReference("fake://bucket/prefix/") @@ -163,23 +80,20 @@ func TestEndToEnd(t *testing.T) { phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) assert.Equal(t, true, phase.Phase().IsSuccess()) - template.Type = "spark_job" + template.Type = "spark" phase = tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, iter) assert.Equal(t, true, phase.Phase().IsSuccess()) - }) t.Run("failed to create a job", func(t *testing.T) { agentPlugin := newMockAgentPlugin() agentPlugin.PluginLoader = func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { - return &MockPlugin{ - Plugin{ - metricScope: iCtx.MetricsScope(), - cfg: GetConfig(), - cs: &ClientSet{ - agentClients: map[string]service.AsyncAgentServiceClient{}, - agentMetadataClients: map[string]service.AgentMetadataServiceClient{}, - }, + return Plugin{ + metricScope: iCtx.MetricsScope(), + cfg: GetConfig(), + cs: &ClientSet{ + agentClients: map[string]service.AsyncAgentServiceClient{}, + agentMetadataClients: map[string]service.AgentMetadataServiceClient{}, }, }, nil } @@ -313,75 +227,35 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext { func newMockAgentPlugin() webapi.PluginEntry { agentClient := new(agentMocks.AsyncAgentServiceClient) - // mockCreateRequest := &admin.CreateTaskRequest{} - // agentClient.On("CreateTask", mock.Anything, mockCreateRequest).Return( - // &admin.CreateTaskResponse{ - // Res: &admin.CreateTaskResponse_ResourceMeta{ - // ResourceMeta: []byte{1, 2, 3, 4}, - // }}, nil) - - agentClient.On("CreateTask", mock.Anything, mock.MatchedBy(func(req *admin.CreateTaskRequest) bool { - // Your custom logic to decide if the condition is met - expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"} - return slices.Equal(req.Template.GetContainer().Args, expectedArgs) - })).Run(func(args mock.Arguments) { - req := args.Get(1).(*admin.CreateTaskRequest) - // Extract the mock.Call object - call := args.Get(0).(mock.Call) - - if slices.Equal(req.Template.GetContainer().Args, []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"}) { - // If condition is met, return a specific response - call.Return(&admin.CreateTaskResponse{ - Res: &admin.CreateTaskResponse_ResourceMeta{ - ResourceMeta: []byte{1, 2, 3, 4}, - }, - }, nil) - } else { - // Else, return a different response or error - call.Return(nil, fmt.Errorf("unexpected arguments")) - } - }).Maybe() - mockGetRequest := &admin.GetTaskRequest{} - agentClient.On("GetTask", mock.Anything, mockGetRequest).Return( + mockCreateRequestMatcher := mock.MatchedBy(func(request *admin.CreateTaskRequest) bool { + expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "/tmp/123"} + return slices.Equal(request.Template.GetContainer().Args, expectedArgs) + }) + agentClient.On("CreateTask", mock.Anything, mockCreateRequestMatcher).Return(&admin.CreateTaskResponse{ + Res: &admin.CreateTaskResponse_ResourceMeta{ + ResourceMeta: []byte{1, 2, 3, 4}, + }}, nil) + + agentClient.On("GetTask", mock.Anything, mock.Anything).Return( &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil) - mockDeleteRequest := &admin.DeleteTaskRequest{} - agentClient.On("DeleteTask", mock.Anything, mockDeleteRequest).Return( + agentClient.On("DeleteTask", mock.Anything, mock.Anything).Return( &admin.DeleteTaskResponse{}, nil) - return webapi.PluginEntry{ - ID: "agent-service", - SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark_job", "api_task"}, - PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { - return &MockPlugin{ - Plugin{ - metricScope: iCtx.MetricsScope(), - cfg: GetConfig(), - cs: &ClientSet{ - agentClients: map[string]service.AsyncAgentServiceClient{ - "": agentClient, - }, - }, - }, - }, nil - }, - } -} + cfg := defaultConfig + cfg.DefaultAgent.Endpoint = "localhost:8000" -func newMockSyncAgentPlugin() webapi.PluginEntry { return webapi.PluginEntry{ ID: "agent-service", - SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark_job", "api_task"}, + SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark", "api_task"}, PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { - return &MockPlugin{ - Plugin{ - metricScope: iCtx.MetricsScope(), - cfg: GetConfig(), - cs: &ClientSet{ - agentClients: map[string]service.AsyncAgentServiceClient{ - "": &MockSyncTask{}, - }, + return Plugin{ + metricScope: iCtx.MetricsScope(), + cfg: &cfg, + cs: &ClientSet{ + agentClients: map[string]service.AsyncAgentServiceClient{ + "localhost:8000": agentClient, }, }, }, nil diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 8dd83f1a50..e66f46f1bc 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -6,55 +6,22 @@ import ( "testing" "time" - "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" flyteIdlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" - ioMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" webapiPlugin "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi/mocks" agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" - "github.com/flyteorg/flyte/flyteplugins/tests" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/promutils" - "github.com/flyteorg/flyte/flytestdlib/storage" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "golang.org/x/exp/maps" ) -func TestSyncTask(t *testing.T) { - tCtx := getTaskContext(t) - taskReader := new(pluginCoreMocks.TaskReader) - - template := flyteIdlCore.TaskTemplate{ - Type: "api_task", - } - - taskReader.On("Read", mock.Anything).Return(&template, nil) - - tCtx.OnTaskReader().Return(taskReader) - - agentPlugin := newMockSyncAgentPlugin() - pluginEntry := pluginmachinery.CreateRemotePlugin(agentPlugin) - plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("create_task_sync_test")) - assert.NoError(t, err) - - inputs, err := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) - assert.NoError(t, err) - basePrefix := storage.DataReference("fake://bucket/prefix/") - inputReader := &ioMocks.InputReader{} - inputReader.OnGetInputPrefixPath().Return(basePrefix) - inputReader.OnGetInputPath().Return(basePrefix + "/inputs.pb") - inputReader.OnGetMatch(mock.Anything).Return(inputs, nil) - tCtx.OnInputReader().Return(inputReader) - - phase := tests.RunPluginEndToEndTest(t, plugin, &template, inputs, nil, nil, nil) - assert.Equal(t, true, phase.Phase().IsSuccess()) -} +const defaultAgentEndpoint = "localhost:8000" func TestPlugin(t *testing.T) { fakeSetupContext := pluginCoreMocks.SetupContext{} @@ -318,8 +285,8 @@ func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { func TestInitializeAgentRegistry(t *testing.T) { agentClients := make(map[string]service.AsyncAgentServiceClient) agentMetadataClients := make(map[string]service.AgentMetadataServiceClient) - agentClients["localhost:80"] = &MockAsyncTask{} - agentMetadataClients["localhost:80"] = getMockMetadataServiceClient() + agentClients[defaultAgentEndpoint] = &agentMocks.AsyncAgentServiceClient{} + agentMetadataClients[defaultAgentEndpoint] = getMockMetadataServiceClient() cs := &ClientSet{ agentClients: agentClients, @@ -327,7 +294,7 @@ func TestInitializeAgentRegistry(t *testing.T) { } cfg := defaultConfig - cfg.Agents = map[string]*Agent{"custom_agent": {Endpoint: "localhost:80"}} + cfg.Agents = map[string]*Agent{"custom_agent": {Endpoint: defaultAgentEndpoint}} cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"} err := SetConfig(&cfg) assert.NoError(t, err) From 34a404966c0befb05a7f86d20daac7d76c6c285c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 31 Jan 2024 03:44:17 -0800 Subject: [PATCH 12/12] nit Signed-off-by: Kevin Su --- .../go/tasks/plugins/webapi/agent/integration_test.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index 998b0fcc14..fe3b45b881 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -64,7 +64,7 @@ func TestEndToEnd(t *testing.T) { inputs, _ := coreutils.MakeLiteralMap(map[string]interface{}{"x": 1}) template := flyteIdlCore.TaskTemplate{ - Type: "databricks", + Type: "spark", Custom: st, Target: &flyteIdlCore.TaskTemplate_Container{ Container: &flyteIdlCore.Container{Args: []string{"pyflyte-fast-execute", "--output-prefix", "/tmp/123"}}, @@ -237,7 +237,10 @@ func newMockAgentPlugin() webapi.PluginEntry { ResourceMeta: []byte{1, 2, 3, 4}, }}, nil) - agentClient.On("GetTask", mock.Anything, mock.Anything).Return( + mockGetRequestMatcher := mock.MatchedBy(func(request *admin.GetTaskRequest) bool { + return request.GetTaskType() == "spark" + }) + agentClient.On("GetTask", mock.Anything, mockGetRequestMatcher).Return( &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil) agentClient.On("DeleteTask", mock.Anything, mock.Anything).Return(