From 0b1955ce2179bb65e58d02dc38671165f804034d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 5 Mar 2024 17:51:41 -0800 Subject: [PATCH 01/21] wip Signed-off-by: Kevin Su --- .../go/tasks/plugins/webapi/agent/client.go | 27 +++++++------- .../tasks/plugins/webapi/agent/client_test.go | 3 +- .../go/tasks/plugins/webapi/agent/plugin.go | 35 ++++++++++--------- .../tasks/plugins/webapi/agent/plugin_test.go | 2 +- 4 files changed, 34 insertions(+), 33 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index b525acc5c3..106ddcde96 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -3,8 +3,6 @@ package agent import ( "context" "crypto/x509" - "fmt" - "golang.org/x/exp/maps" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -98,7 +96,7 @@ func getFinalContext(ctx context.Context, operation string, agent *Deployment) ( return context.WithTimeout(ctx, timeout) } -func initializeAgentRegistry(cs *ClientSet) (Registry, error) { +func createAgentRegistry(ctx context.Context, cs *ClientSet) Registry { agentRegistry := make(Registry) cfg := GetConfig() var agentDeployments []*Deployment @@ -116,7 +114,7 @@ func initializeAgentRegistry(cs *ClientSet) (Registry, error) { for _, agentDeployment := range agentDeployments { client := cs.agentMetadataClients[agentDeployment.Endpoint] - finalCtx, cancel := getFinalContext(context.Background(), "ListAgents", agentDeployment) + finalCtx, cancel := getFinalContext(ctx, "ListAgents", agentDeployment) defer cancel() res, err := client.ListAgents(finalCtx, &admin.ListAgentsRequest{}) @@ -124,15 +122,15 @@ func initializeAgentRegistry(cs *ClientSet) (Registry, error) { grpcStatus, ok := status.FromError(err) if grpcStatus.Code() == codes.Unimplemented { // we should not panic here, as we want to continue to support old agent settings - logger.Infof(context.Background(), "list agent method not implemented for agent: [%v]", agentDeployment) + logger.Infof(finalCtx, "list agent method not implemented for agent: [%v]", agentDeployment.Endpoint) continue } if !ok { - return nil, fmt.Errorf("failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment, err) + logger.Warningf(finalCtx, "failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment.Endpoint, err) } - return nil, fmt.Errorf("failed to list agent: [%v] with error: [%v]", agentDeployment, err) + logger.Warningf(finalCtx, "failed to list agent: [%v] with error: [%v]", agentDeployment.Endpoint, err) } for _, agent := range res.GetAgents() { @@ -147,15 +145,16 @@ func initializeAgentRegistry(cs *ClientSet) (Registry, error) { agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync} agentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent} } - logger.Infof(context.Background(), "[%v] is a sync agent: [%v]", agent.Name, agent.IsSync) - logger.Infof(context.Background(), "[%v] supports task category: [%v]", agent.Name, supportedTaskCategories) + //logger.Infof(finalCtx, "[%v] ", agent) + //logger.Infof(context.Background(), "[%v] is a sync agent: [%v]", agent.Name, agent.IsSync) + //logger.Infof(context.Background(), "[%v] supports task category: [%v]", agent.Name, supportedTaskCategories) } } - - return agentRegistry, nil + logger.Infof(ctx, "agentRegistry: [%v] ", agentRegistry) + return agentRegistry } -func initializeClients(ctx context.Context) (*ClientSet, error) { +func createAgentClientSets(ctx context.Context) *ClientSet { asyncAgentClients := make(map[string]service.AsyncAgentServiceClient) syncAgentClients := make(map[string]service.SyncAgentServiceClient) agentMetadataClients := make(map[string]service.AgentMetadataServiceClient) @@ -170,7 +169,7 @@ func initializeClients(ctx context.Context) (*ClientSet, error) { for _, agentService := range agentDeployments { conn, err := getGrpcConnection(ctx, agentService) if err != nil { - return nil, err + logger.Warningf(ctx, "failed to create connection to agent: [%v] with error: [%v]", agentService, err) } syncAgentClients[agentService.Endpoint] = service.NewSyncAgentServiceClient(conn) asyncAgentClients[agentService.Endpoint] = service.NewAsyncAgentServiceClient(conn) @@ -181,5 +180,5 @@ func initializeClients(ctx context.Context) (*ClientSet, error) { syncAgentClients: syncAgentClients, asyncAgentClients: asyncAgentClients, agentMetadataClients: agentMetadataClients, - }, nil + } } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index 4ad7f8cbaa..c4eadb25e6 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -20,8 +20,7 @@ func TestInitializeClients(t *testing.T) { ctx := context.Background() err := SetConfig(&cfg) assert.NoError(t, err) - cs, err := initializeClients(ctx) - assert.NoError(t, err) + cs := createAgentClientSets(ctx) assert.NotNil(t, cs) _, ok := cs.syncAgentClients["y"] assert.True(t, ok) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 11ef7871b3..fa8ecc7e43 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -5,6 +5,8 @@ import ( "encoding/gob" "fmt" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" + "k8s.io/apimachinery/pkg/util/wait" "time" "golang.org/x/exp/maps" @@ -12,7 +14,6 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" pluginErrors "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/template" flyteIO "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io" @@ -315,6 +316,15 @@ func (p Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (ser return client, nil } +func (p Plugin) watchAgents(ctx context.Context) { + go wait.Until(func() { + cs := createAgentClientSets(ctx) + agentRegistry := createAgentRegistry(ctx, cs) + p.agentRegistry = agentRegistry + + }, time.Duration(5)*time.Second, ctx.Done()) +} + func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *flyteIdl.LiteralMap) error { taskTemplate, err := taskCtx.TaskReader().Read(ctx) if err != nil { @@ -341,7 +351,6 @@ func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config, agentRegistry if agent, exists := agentRegistry[taskCategory.Name][taskCategory.Version]; exists { return agent.AgentDeployment, agent.IsSync } - return &cfg.DefaultAgent, false } @@ -358,18 +367,11 @@ func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata } func newAgentPlugin() webapi.PluginEntry { - cs, err := initializeClients(context.Background()) - if err != nil { - // We should wait for all agents to be up and running before starting the server - panic(fmt.Sprintf("failed to initialize clients with error: %v", err)) - } - - agentRegistry, err := initializeAgentRegistry(cs) - if err != nil { - panic(fmt.Sprintf("failed to initialize agent registry with error: %v", err)) - } - + ctx := context.Background() cfg := GetConfig() + + cs := createAgentClientSets(ctx) + agentRegistry := createAgentRegistry(ctx, cs) supportedTaskTypes := append(maps.Keys(agentRegistry), cfg.SupportedTaskTypes...) logger.Infof(context.Background(), "AgentDeployment service supports task types: %v", supportedTaskTypes) @@ -377,12 +379,14 @@ func newAgentPlugin() webapi.PluginEntry { ID: "agent-service", SupportedTaskTypes: supportedTaskTypes, PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { - return &Plugin{ + plugin := &Plugin{ metricScope: iCtx.MetricsScope(), cfg: cfg, cs: cs, agentRegistry: agentRegistry, - }, nil + } + plugin.watchAgents(ctx) + return plugin, nil }, } } @@ -390,6 +394,5 @@ func newAgentPlugin() webapi.PluginEntry { func RegisterAgentPlugin() { gob.Register(ResourceMetaWrapper{}) gob.Register(ResourceWrapper{}) - pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin()) } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 9fa36c5c42..47834a84c8 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -307,7 +307,7 @@ func TestInitializeAgentRegistry(t *testing.T) { cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"} err := SetConfig(&cfg) assert.NoError(t, err) - agentRegistry, err := initializeAgentRegistry(cs) + agentRegistry, err := createAgentRegistry(cs) assert.NoError(t, err) // In golang, the order of keys in a map is random. So, we sort the keys before asserting. From 0cffe52f0c15e303921e0e088801a6d701bae1c7 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 6 Mar 2024 13:23:59 -0800 Subject: [PATCH 02/21] Watch agent service Signed-off-by: Kevin Su --- .../go/tasks/plugins/webapi/agent/client.go | 14 +++++--- .../go/tasks/plugins/webapi/agent/config.go | 3 ++ .../go/tasks/plugins/webapi/agent/plugin.go | 35 ++++++++++--------- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 106ddcde96..8378778a0f 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -145,12 +145,18 @@ func createAgentRegistry(ctx context.Context, cs *ClientSet) Registry { agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync} agentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent} } - //logger.Infof(finalCtx, "[%v] ", agent) - //logger.Infof(context.Background(), "[%v] is a sync agent: [%v]", agent.Name, agent.IsSync) - //logger.Infof(context.Background(), "[%v] supports task category: [%v]", agent.Name, supportedTaskCategories) + } + // If the agent doesn't implement the metadata service, we construct the registry based on the configuration + for taskType, agentDeploymentID := range cfg.AgentForTaskTypes { + agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID] + if ok { + agent := &Agent{AgentDeployment: agentDeployment, IsSync: false} + agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent} + } } } - logger.Infof(ctx, "agentRegistry: [%v] ", agentRegistry) + supportedTaskTypes := append(maps.Keys(agentRegistry)) + logger.Debugf(context.Background(), "AgentDeployment service supports task types: %v", supportedTaskTypes) return agentRegistry } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/config.go b/flyteplugins/go/tasks/plugins/webapi/agent/config.go index 3f9fd354b6..5bd8e255ca 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/config.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/config.go @@ -47,6 +47,7 @@ var ( // AsyncPlugin should be registered to at least one task type. // Reference: https://github.com/flyteorg/flyte/blob/master/flyteplugins/go/tasks/pluginmachinery/registry.go#L27 SupportedTaskTypes: []string{"task_type_1", "task_type_2"}, + PollInterval: config.Duration{Duration: 10 * time.Second}, } configSection = pluginsConfig.MustRegisterSubSection("agent-service", &defaultConfig) @@ -71,6 +72,8 @@ type Config struct { // SupportedTaskTypes is a list of task types that are supported by this plugin. SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."` + + PollInterval config.Duration `json:"pollInterval" pflag:",The interval at which the plugin should poll the agent for metadata updates."` } type Deployment struct { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index fa8ecc7e43..adf57426a5 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -24,12 +24,12 @@ import ( ) type Registry map[string]map[int32]*Agent // map[taskTypeName][taskTypeVersion] => Agent +var agentRegistry Registry type Plugin struct { - metricScope promutils.Scope - cfg *Config - cs *ClientSet - agentRegistry Registry + metricScope promutils.Scope + cfg *Config + cs *ClientSet } type ResourceWrapper struct { @@ -91,7 +91,8 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String() taskCategory := admin.TaskCategory{Name: taskTemplate.Type, Version: taskTemplate.TaskTypeVersion} - agent, isSync := getFinalAgent(&taskCategory, p.cfg, p.agentRegistry) + logger.Infof(ctx, "AgentRegistry AgentRegistry AgentRegistry: %v", agentRegistry) + agent, isSync := getFinalAgent(&taskCategory, p.cfg) finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent) defer cancel() @@ -187,7 +188,7 @@ func (p Plugin) ExecuteTaskSync( func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg, p.agentRegistry) + agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg) client, err := p.getAsyncAgentClient(ctx, agent) if err != nil { @@ -220,7 +221,7 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error return nil } metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg, p.agentRegistry) + agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg) client, err := p.getAsyncAgentClient(ctx, agent) if err != nil { @@ -319,9 +320,11 @@ func (p Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (ser func (p Plugin) watchAgents(ctx context.Context) { go wait.Until(func() { cs := createAgentClientSets(ctx) - agentRegistry := createAgentRegistry(ctx, cs) - p.agentRegistry = agentRegistry - + agentRegistry = createAgentRegistry(ctx, cs) + if d, ok := agentRegistry["airflow"]; ok { + logger.Infof(ctx, "tset: %v", d[0].AgentDeployment) + } + logger.Infof(ctx, "agentRegistry: %v", agentRegistry) }, time.Duration(5)*time.Second, ctx.Done()) } @@ -347,7 +350,9 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *fly return taskCtx.OutputWriter().Put(ctx, opReader) } -func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config, agentRegistry Registry) (*Deployment, bool) { +func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) { + logger.Infof(context.Background(), "taskCategory.Name [%v] taskCategory.Version [%v]", taskCategory.Name, taskCategory.Version) + logger.Infof(context.Background(), "agentRegistry [%v]", agentRegistry) if agent, exists := agentRegistry[taskCategory.Name][taskCategory.Version]; exists { return agent.AgentDeployment, agent.IsSync } @@ -373,17 +378,15 @@ func newAgentPlugin() webapi.PluginEntry { cs := createAgentClientSets(ctx) agentRegistry := createAgentRegistry(ctx, cs) supportedTaskTypes := append(maps.Keys(agentRegistry), cfg.SupportedTaskTypes...) - logger.Infof(context.Background(), "AgentDeployment service supports task types: %v", supportedTaskTypes) return webapi.PluginEntry{ ID: "agent-service", SupportedTaskTypes: supportedTaskTypes, PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { plugin := &Plugin{ - metricScope: iCtx.MetricsScope(), - cfg: cfg, - cs: cs, - agentRegistry: agentRegistry, + metricScope: iCtx.MetricsScope(), + cfg: cfg, + cs: cs, } plugin.watchAgents(ctx) return plugin, nil From ac2a89f4288eae812137b9874ec69d24da8090a8 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 6 Mar 2024 13:33:39 -0800 Subject: [PATCH 03/21] lint Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/config.go | 1 + flyteplugins/go/tasks/plugins/webapi/agent/plugin.go | 9 +-------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/config.go b/flyteplugins/go/tasks/plugins/webapi/agent/config.go index 5bd8e255ca..f26499f320 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/config.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/config.go @@ -73,6 +73,7 @@ type Config struct { // SupportedTaskTypes is a list of task types that are supported by this plugin. SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."` + // PollInterval is the interval at which the plugin should poll the agent for metadata updates PollInterval config.Duration `json:"pollInterval" pflag:",The interval at which the plugin should poll the agent for metadata updates."` } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index adf57426a5..79dbcd4da4 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -91,7 +91,6 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String() taskCategory := admin.TaskCategory{Name: taskTemplate.Type, Version: taskTemplate.TaskTypeVersion} - logger.Infof(ctx, "AgentRegistry AgentRegistry AgentRegistry: %v", agentRegistry) agent, isSync := getFinalAgent(&taskCategory, p.cfg) finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent) @@ -321,11 +320,7 @@ func (p Plugin) watchAgents(ctx context.Context) { go wait.Until(func() { cs := createAgentClientSets(ctx) agentRegistry = createAgentRegistry(ctx, cs) - if d, ok := agentRegistry["airflow"]; ok { - logger.Infof(ctx, "tset: %v", d[0].AgentDeployment) - } - logger.Infof(ctx, "agentRegistry: %v", agentRegistry) - }, time.Duration(5)*time.Second, ctx.Done()) + }, p.cfg.PollInterval.Duration, ctx.Done()) } func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *flyteIdl.LiteralMap) error { @@ -351,8 +346,6 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *fly } func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) { - logger.Infof(context.Background(), "taskCategory.Name [%v] taskCategory.Version [%v]", taskCategory.Name, taskCategory.Version) - logger.Infof(context.Background(), "agentRegistry [%v]", agentRegistry) if agent, exists := agentRegistry[taskCategory.Name][taskCategory.Version]; exists { return agent.AgentDeployment, agent.IsSync } From b88ecb979247176d8bf279d8b254d705ca9f7480 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 6 Mar 2024 14:13:25 -0800 Subject: [PATCH 04/21] nit Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/client.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 8378778a0f..cbb6d4312f 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -122,15 +122,17 @@ func createAgentRegistry(ctx context.Context, cs *ClientSet) Registry { grpcStatus, ok := status.FromError(err) if grpcStatus.Code() == codes.Unimplemented { // we should not panic here, as we want to continue to support old agent settings - logger.Infof(finalCtx, "list agent method not implemented for agent: [%v]", agentDeployment.Endpoint) + logger.Warningf(finalCtx, "list agent method not implemented for agent: [%v]", agentDeployment.Endpoint) continue } if !ok { logger.Warningf(finalCtx, "failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment.Endpoint, err) + continue } logger.Warningf(finalCtx, "failed to list agent: [%v] with error: [%v]", agentDeployment.Endpoint, err) + continue } for _, agent := range res.GetAgents() { From d29998e28913b20c28cc28e49bc84804bf1370a1 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 6 Mar 2024 16:26:34 -0800 Subject: [PATCH 05/21] Fix test Signed-off-by: Kevin Su --- .../tasks/plugins/webapi/agent/integration_test.go | 6 +++++- .../go/tasks/plugins/webapi/agent/plugin_test.go | 13 ++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index 689527ee3b..edae05577c 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -35,6 +35,10 @@ import ( ) func TestEndToEnd(t *testing.T) { + agentRegistry = Registry{ + "openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}}, + "spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}}, + } iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { return nil } @@ -118,6 +122,7 @@ func TestEndToEnd(t *testing.T) { cfg: GetConfig(), cs: &ClientSet{ asyncAgentClients: map[string]service.AsyncAgentServiceClient{}, + syncAgentClients: map[string]service.SyncAgentServiceClient{}, agentMetadataClients: map[string]service.AgentMetadataServiceClient{}, }, }, nil @@ -323,7 +328,6 @@ func newMockSyncAgentPlugin() webapi.PluginEntry { defaultAgentEndpoint: syncAgentClient, }, }, - agentRegistry: Registry{"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}}}, }, nil }, } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 47834a84c8..22855c8a63 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -60,15 +60,15 @@ func TestPlugin(t *testing.T) { t.Run("test getFinalAgent", func(t *testing.T) { agent := &Agent{AgentDeployment: &Deployment{Endpoint: "localhost:80"}} - agentRegistry := Registry{"spark": {defaultTaskTypeVersion: agent}} + agentRegistry = Registry{"spark": {defaultTaskTypeVersion: agent}} spark := &admin.TaskCategory{Name: "spark", Version: defaultTaskTypeVersion} foo := &admin.TaskCategory{Name: "foo", Version: defaultTaskTypeVersion} bar := &admin.TaskCategory{Name: "bar", Version: defaultTaskTypeVersion} - agentDeployment, _ := getFinalAgent(spark, &cfg, agentRegistry) + agentDeployment, _ := getFinalAgent(spark, &cfg) assert.Equal(t, agentDeployment.Endpoint, "localhost:80") - agentDeployment, _ = getFinalAgent(foo, &cfg, agentRegistry) + agentDeployment, _ = getFinalAgent(foo, &cfg) assert.Equal(t, agentDeployment.Endpoint, cfg.DefaultAgent.Endpoint) - agentDeployment, _ = getFinalAgent(bar, &cfg, agentRegistry) + agentDeployment, _ = getFinalAgent(bar, &cfg) assert.Equal(t, agentDeployment.Endpoint, cfg.DefaultAgent.Endpoint) }) @@ -307,11 +307,10 @@ func TestInitializeAgentRegistry(t *testing.T) { cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"} err := SetConfig(&cfg) assert.NoError(t, err) - agentRegistry, err := createAgentRegistry(cs) - assert.NoError(t, err) + registry := createAgentRegistry(context.Background(), cs) // In golang, the order of keys in a map is random. So, we sort the keys before asserting. - agentRegistryKeys := maps.Keys(agentRegistry) + agentRegistryKeys := maps.Keys(registry) sort.Strings(agentRegistryKeys) assert.Equal(t, agentRegistryKeys, []string{"task1", "task2", "task3"}) From 36e6f870fabcf4be2f9b23a6f27d0b941f5cc2aa Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 7 Mar 2024 00:13:37 -0800 Subject: [PATCH 06/21] lint Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/client.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index cbb6d4312f..289a1a680c 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -157,8 +157,7 @@ func createAgentRegistry(ctx context.Context, cs *ClientSet) Registry { } } } - supportedTaskTypes := append(maps.Keys(agentRegistry)) - logger.Debugf(context.Background(), "AgentDeployment service supports task types: %v", supportedTaskTypes) + logger.Debugf(context.Background(), "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry)) return agentRegistry } From 063846f1707d83095156f046754260e2cd246ea8 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 7 Mar 2024 11:34:10 -0800 Subject: [PATCH 07/21] nit Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 289a1a680c..cfe87f8dc2 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -157,7 +157,7 @@ func createAgentRegistry(ctx context.Context, cs *ClientSet) Registry { } } } - logger.Debugf(context.Background(), "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry)) + logger.Debugf(ctx, "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry)) return agentRegistry } From 7187adec3dc9d3b95369bacc1c169c88b73a059c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 7 Mar 2024 15:40:58 -0800 Subject: [PATCH 08/21] updateAgentClientSets instead Signed-off-by: Kevin Su --- .../go/tasks/plugins/webapi/agent/client.go | 23 +++++++------------ .../tasks/plugins/webapi/agent/client_test.go | 9 ++++++-- .../go/tasks/plugins/webapi/agent/plugin.go | 16 +++++++++---- .../tasks/plugins/webapi/agent/plugin_test.go | 2 +- 4 files changed, 27 insertions(+), 23 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index cfe87f8dc2..aa727cf0a9 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -96,7 +96,7 @@ func getFinalContext(ctx context.Context, operation string, agent *Deployment) ( return context.WithTimeout(ctx, timeout) } -func createAgentRegistry(ctx context.Context, cs *ClientSet) Registry { +func updateAgentRegistry(ctx context.Context, cs *ClientSet) Registry { agentRegistry := make(Registry) cfg := GetConfig() var agentDeployments []*Deployment @@ -161,11 +161,7 @@ func createAgentRegistry(ctx context.Context, cs *ClientSet) Registry { return agentRegistry } -func createAgentClientSets(ctx context.Context) *ClientSet { - asyncAgentClients := make(map[string]service.AsyncAgentServiceClient) - syncAgentClients := make(map[string]service.SyncAgentServiceClient) - agentMetadataClients := make(map[string]service.AgentMetadataServiceClient) - +func updateAgentClientSets(ctx context.Context, clientSet *ClientSet) { var agentDeployments []*Deployment cfg := GetConfig() @@ -174,18 +170,15 @@ func createAgentClientSets(ctx context.Context) *ClientSet { } agentDeployments = append(agentDeployments, maps.Values(cfg.AgentDeployments)...) for _, agentService := range agentDeployments { + if _, ok := clientSet.agentMetadataClients[agentService.Endpoint]; ok { + continue + } conn, err := getGrpcConnection(ctx, agentService) if err != nil { logger.Warningf(ctx, "failed to create connection to agent: [%v] with error: [%v]", agentService, err) } - syncAgentClients[agentService.Endpoint] = service.NewSyncAgentServiceClient(conn) - asyncAgentClients[agentService.Endpoint] = service.NewAsyncAgentServiceClient(conn) - agentMetadataClients[agentService.Endpoint] = service.NewAgentMetadataServiceClient(conn) - } - - return &ClientSet{ - syncAgentClients: syncAgentClients, - asyncAgentClients: asyncAgentClients, - agentMetadataClients: agentMetadataClients, + clientSet.syncAgentClients[agentService.Endpoint] = service.NewSyncAgentServiceClient(conn) + clientSet.asyncAgentClients[agentService.Endpoint] = service.NewAsyncAgentServiceClient(conn) + clientSet.agentMetadataClients[agentService.Endpoint] = service.NewAgentMetadataServiceClient(conn) } } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index c4eadb25e6..2902c2cb27 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -2,6 +2,7 @@ package agent import ( "context" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "testing" "github.com/stretchr/testify/assert" @@ -20,8 +21,12 @@ func TestInitializeClients(t *testing.T) { ctx := context.Background() err := SetConfig(&cfg) assert.NoError(t, err) - cs := createAgentClientSets(ctx) - assert.NotNil(t, cs) + cs := &ClientSet{ + asyncAgentClients: make(map[string]service.AsyncAgentServiceClient), + syncAgentClients: make(map[string]service.SyncAgentServiceClient), + agentMetadataClients: make(map[string]service.AgentMetadataServiceClient), + } + updateAgentClientSets(ctx, cs) _, ok := cs.syncAgentClients["y"] assert.True(t, ok) _, ok = cs.asyncAgentClients["x"] diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 79dbcd4da4..e2439f1c46 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -318,8 +318,8 @@ func (p Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (ser func (p Plugin) watchAgents(ctx context.Context) { go wait.Until(func() { - cs := createAgentClientSets(ctx) - agentRegistry = createAgentRegistry(ctx, cs) + updateAgentClientSets(ctx, p.cs) + agentRegistry = updateAgentRegistry(ctx, p.cs) }, p.cfg.PollInterval.Duration, ctx.Done()) } @@ -368,8 +368,14 @@ func newAgentPlugin() webapi.PluginEntry { ctx := context.Background() cfg := GetConfig() - cs := createAgentClientSets(ctx) - agentRegistry := createAgentRegistry(ctx, cs) + clientSet := &ClientSet{ + asyncAgentClients: make(map[string]service.AsyncAgentServiceClient), + syncAgentClients: make(map[string]service.SyncAgentServiceClient), + agentMetadataClients: make(map[string]service.AgentMetadataServiceClient), + } + + updateAgentClientSets(ctx, clientSet) + agentRegistry := updateAgentRegistry(ctx, clientSet) supportedTaskTypes := append(maps.Keys(agentRegistry), cfg.SupportedTaskTypes...) return webapi.PluginEntry{ @@ -379,7 +385,7 @@ func newAgentPlugin() webapi.PluginEntry { plugin := &Plugin{ metricScope: iCtx.MetricsScope(), cfg: cfg, - cs: cs, + cs: clientSet, } plugin.watchAgents(ctx) return plugin, nil diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 22855c8a63..a6ac062c70 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -307,7 +307,7 @@ func TestInitializeAgentRegistry(t *testing.T) { cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"} err := SetConfig(&cfg) assert.NoError(t, err) - registry := createAgentRegistry(context.Background(), cs) + registry := updateAgentRegistry(context.Background(), cs) // In golang, the order of keys in a map is random. So, we sort the keys before asserting. agentRegistryKeys := maps.Keys(registry) From 3dca1a7735f40607590d8058b4bb4cf7f4d47f05 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 7 Mar 2024 16:25:02 -0800 Subject: [PATCH 09/21] lock Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/plugin.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index e2439f1c46..19712dccfb 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -7,6 +7,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" "k8s.io/apimachinery/pkg/util/wait" + "sync" "time" "golang.org/x/exp/maps" @@ -25,6 +26,7 @@ import ( type Registry map[string]map[int32]*Agent // map[taskTypeName][taskTypeVersion] => Agent var agentRegistry Registry +var mu sync.RWMutex type Plugin struct { metricScope promutils.Scope @@ -318,8 +320,10 @@ func (p Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (ser func (p Plugin) watchAgents(ctx context.Context) { go wait.Until(func() { + mu.Lock() updateAgentClientSets(ctx, p.cs) agentRegistry = updateAgentRegistry(ctx, p.cs) + mu.Unlock() }, p.cfg.PollInterval.Duration, ctx.Done()) } @@ -346,9 +350,11 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *fly } func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) { + mu.RLock() if agent, exists := agentRegistry[taskCategory.Name][taskCategory.Version]; exists { return agent.AgentDeployment, agent.IsSync } + mu.RUnlock() return &cfg.DefaultAgent, false } @@ -373,9 +379,10 @@ func newAgentPlugin() webapi.PluginEntry { syncAgentClients: make(map[string]service.SyncAgentServiceClient), agentMetadataClients: make(map[string]service.AgentMetadataServiceClient), } - + mu.Lock() updateAgentClientSets(ctx, clientSet) agentRegistry := updateAgentRegistry(ctx, clientSet) + mu.Unlock() supportedTaskTypes := append(maps.Keys(agentRegistry), cfg.SupportedTaskTypes...) return webapi.PluginEntry{ From 60377d9cde349b01a3b96d26e0d0abb368413d15 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 7 Mar 2024 16:26:49 -0800 Subject: [PATCH 10/21] nit Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/plugin.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 19712dccfb..9ff641adb7 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -379,10 +379,8 @@ func newAgentPlugin() webapi.PluginEntry { syncAgentClients: make(map[string]service.SyncAgentServiceClient), agentMetadataClients: make(map[string]service.AgentMetadataServiceClient), } - mu.Lock() updateAgentClientSets(ctx, clientSet) agentRegistry := updateAgentRegistry(ctx, clientSet) - mu.Unlock() supportedTaskTypes := append(maps.Keys(agentRegistry), cfg.SupportedTaskTypes...) return webapi.PluginEntry{ From 32a907e84241e5699593d95b3ce549ce6a6d9b38 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 8 Mar 2024 12:13:02 -0800 Subject: [PATCH 11/21] defer Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/plugin.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 9ff641adb7..9f00d3331d 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -321,9 +321,9 @@ func (p Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (ser func (p Plugin) watchAgents(ctx context.Context) { go wait.Until(func() { mu.Lock() + defer mu.Unlock() updateAgentClientSets(ctx, p.cs) agentRegistry = updateAgentRegistry(ctx, p.cs) - mu.Unlock() }, p.cfg.PollInterval.Duration, ctx.Done()) } @@ -351,10 +351,10 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *fly func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) { mu.RLock() + defer mu.RUnlock() if agent, exists := agentRegistry[taskCategory.Name][taskCategory.Version]; exists { return agent.AgentDeployment, agent.IsSync } - mu.RUnlock() return &cfg.DefaultAgent, false } From 24feb98c85d347c60ad8a9f1809bee88b4cbcca8 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 8 Jun 2024 00:11:11 -0700 Subject: [PATCH 12/21] lint Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/client.go | 1 + flyteplugins/go/tasks/plugins/webapi/agent/client_test.go | 3 ++- flyteplugins/go/tasks/plugins/webapi/agent/plugin.go | 5 +++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index aa727cf0a9..18b44816a7 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -3,6 +3,7 @@ package agent import ( "context" "crypto/x509" + "golang.org/x/exp/maps" "google.golang.org/grpc" "google.golang.org/grpc/codes" diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index 2902c2cb27..cbc75b5cc4 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -2,10 +2,11 @@ package agent import ( "context" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "testing" "github.com/stretchr/testify/assert" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" ) func TestInitializeClients(t *testing.T) { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index abc3d18087..1887d5b94b 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -4,11 +4,12 @@ import ( "context" "encoding/gob" "fmt" - "golang.org/x/exp/maps" - "k8s.io/apimachinery/pkg/util/wait" "sync" "time" + "golang.org/x/exp/maps" + "k8s.io/apimachinery/pkg/util/wait" + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" From d83e51f68909bbc2d4b956f9ea7d21592fbe1953 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 8 Jun 2024 01:32:59 -0700 Subject: [PATCH 13/21] Add getter and setter Signed-off-by: Kevin Su --- .../go/tasks/plugins/webapi/agent/client.go | 15 ++++++-- .../tasks/plugins/webapi/agent/client_test.go | 2 +- .../go/tasks/plugins/webapi/agent/plugin.go | 38 +++++++++++-------- .../tasks/plugins/webapi/agent/plugin_test.go | 4 +- 4 files changed, 38 insertions(+), 21 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 18b44816a7..9580dc974a 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -97,7 +97,7 @@ func getFinalContext(ctx context.Context, operation string, agent *Deployment) ( return context.WithTimeout(ctx, timeout) } -func updateAgentRegistry(ctx context.Context, cs *ClientSet) Registry { +func updateAgentRegistry(ctx context.Context, cs *ClientSet) { agentRegistry := make(Registry) cfg := GetConfig() var agentDeployments []*Deployment @@ -159,10 +159,18 @@ func updateAgentRegistry(ctx context.Context, cs *ClientSet) Registry { } } logger.Debugf(ctx, "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry)) - return agentRegistry + SetAgentRegistry(agentRegistry) } -func updateAgentClientSets(ctx context.Context, clientSet *ClientSet) { +func initializeAgentClientSets(ctx context.Context) *ClientSet { + logger.Infof(ctx, "Initializing agent clients") + + clientSet := &ClientSet{ + asyncAgentClients: make(map[string]service.AsyncAgentServiceClient), + syncAgentClients: make(map[string]service.SyncAgentServiceClient), + agentMetadataClients: make(map[string]service.AgentMetadataServiceClient), + } + var agentDeployments []*Deployment cfg := GetConfig() @@ -182,4 +190,5 @@ func updateAgentClientSets(ctx context.Context, clientSet *ClientSet) { clientSet.asyncAgentClients[agentService.Endpoint] = service.NewAsyncAgentServiceClient(conn) clientSet.agentMetadataClients[agentService.Endpoint] = service.NewAgentMetadataServiceClient(conn) } + return clientSet } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index cbc75b5cc4..93f808e98c 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -27,7 +27,7 @@ func TestInitializeClients(t *testing.T) { syncAgentClients: make(map[string]service.SyncAgentServiceClient), agentMetadataClients: make(map[string]service.AgentMetadataServiceClient), } - updateAgentClientSets(ctx, cs) + cs = initializeAgentClientSets(ctx) _, ok := cs.syncAgentClients["y"] assert.True(t, ok) _, ok = cs.asyncAgentClients["x"] diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 1887d5b94b..8bc60b68fb 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -25,8 +25,23 @@ import ( ) type Registry map[string]map[int32]*Agent // map[taskTypeName][taskTypeVersion] => Agent -var agentRegistry Registry -var mu sync.RWMutex + +var ( + agentRegistry Registry + mu sync.RWMutex +) + +func GetAgentRegistry() Registry { + mu.Lock() + defer mu.Unlock() + return agentRegistry +} + +func SetAgentRegistry(r Registry) { + mu.Lock() + agentRegistry = r + mu.Unlock() +} type Plugin struct { metricScope promutils.Scope @@ -329,8 +344,7 @@ func (p Plugin) watchAgents(ctx context.Context) { go wait.Until(func() { mu.Lock() defer mu.Unlock() - updateAgentClientSets(ctx, p.cs) - agentRegistry = updateAgentRegistry(ctx, p.cs) + updateAgentRegistry(ctx, p.cs) }, p.cfg.PollInterval.Duration, ctx.Done()) } @@ -357,9 +371,8 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *fly } func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) { - mu.RLock() - defer mu.RUnlock() - if agent, exists := agentRegistry[taskCategory.Name][taskCategory.Version]; exists { + r := GetAgentRegistry() + if agent, exists := r[taskCategory.Name][taskCategory.Version]; exists { return agent.AgentDeployment, agent.IsSync } return &cfg.DefaultAgent, false @@ -383,14 +396,9 @@ func newAgentPlugin() webapi.PluginEntry { ctx := context.Background() cfg := GetConfig() - clientSet := &ClientSet{ - asyncAgentClients: make(map[string]service.AsyncAgentServiceClient), - syncAgentClients: make(map[string]service.SyncAgentServiceClient), - agentMetadataClients: make(map[string]service.AgentMetadataServiceClient), - } - updateAgentClientSets(ctx, clientSet) - agentRegistry := updateAgentRegistry(ctx, clientSet) - supportedTaskTypes := append(maps.Keys(agentRegistry), cfg.SupportedTaskTypes...) + clientSet := initializeAgentClientSets(ctx) + updateAgentRegistry(ctx, clientSet) + supportedTaskTypes := append(maps.Keys(GetAgentRegistry()), cfg.SupportedTaskTypes...) return webapi.PluginEntry{ ID: "agent-service", diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index bb8b5a5029..149d74646d 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -318,10 +318,10 @@ func TestInitializeAgentRegistry(t *testing.T) { cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"} err := SetConfig(&cfg) assert.NoError(t, err) - registry := updateAgentRegistry(context.Background(), cs) + updateAgentRegistry(context.Background(), cs) // In golang, the order of keys in a map is random. So, we sort the keys before asserting. - agentRegistryKeys := maps.Keys(registry) + agentRegistryKeys := maps.Keys(GetAgentRegistry()) sort.Strings(agentRegistryKeys) assert.Equal(t, agentRegistryKeys, []string{"task1", "task2", "task3"}) From 86c0ccc7697dbba14c48c74d0fbd3b61ce25d60f Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 8 Jun 2024 02:22:11 -0700 Subject: [PATCH 14/21] refactor(webapi): Improve agent client handling and logging Signed-off-by: Kevin Su --- .../go/tasks/plugins/webapi/agent/client.go | 24 ++++++++++++------- .../tasks/plugins/webapi/agent/client_test.go | 2 +- .../go/tasks/plugins/webapi/agent/plugin.go | 9 ++++--- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 9580dc974a..82fa68ca00 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -113,7 +113,11 @@ func updateAgentRegistry(ctx context.Context, cs *ClientSet) { } agentDeployments = append(agentDeployments, maps.Values(cfg.AgentDeployments)...) for _, agentDeployment := range agentDeployments { - client := cs.agentMetadataClients[agentDeployment.Endpoint] + client, ok := cs.agentMetadataClients[agentDeployment.Endpoint] + if !ok { + logger.Warningf(ctx, "Agent client not found in the clientSet for the endpoint: %v", agentDeployment.Endpoint) + continue + } finalCtx, cancel := getFinalContext(ctx, "ListAgents", agentDeployment) defer cancel() @@ -162,7 +166,7 @@ func updateAgentRegistry(ctx context.Context, cs *ClientSet) { SetAgentRegistry(agentRegistry) } -func initializeAgentClientSets(ctx context.Context) *ClientSet { +func getAgentClientSets(ctx context.Context) *ClientSet { logger.Infof(ctx, "Initializing agent clients") clientSet := &ClientSet{ @@ -178,17 +182,19 @@ func initializeAgentClientSets(ctx context.Context) *ClientSet { agentDeployments = append(agentDeployments, &cfg.DefaultAgent) } agentDeployments = append(agentDeployments, maps.Values(cfg.AgentDeployments)...) - for _, agentService := range agentDeployments { - if _, ok := clientSet.agentMetadataClients[agentService.Endpoint]; ok { + for _, agentDeployment := range agentDeployments { + if _, ok := clientSet.agentMetadataClients[agentDeployment.Endpoint]; ok { + logger.Infof(ctx, "Agent client already initialized for [%v]", agentDeployment.Endpoint) continue } - conn, err := getGrpcConnection(ctx, agentService) + conn, err := getGrpcConnection(ctx, agentDeployment) if err != nil { - logger.Warningf(ctx, "failed to create connection to agent: [%v] with error: [%v]", agentService, err) + logger.Errorf(ctx, "failed to create connection to agent: [%v] with error: [%v]", agentDeployment, err) + continue } - clientSet.syncAgentClients[agentService.Endpoint] = service.NewSyncAgentServiceClient(conn) - clientSet.asyncAgentClients[agentService.Endpoint] = service.NewAsyncAgentServiceClient(conn) - clientSet.agentMetadataClients[agentService.Endpoint] = service.NewAgentMetadataServiceClient(conn) + clientSet.syncAgentClients[agentDeployment.Endpoint] = service.NewSyncAgentServiceClient(conn) + clientSet.asyncAgentClients[agentDeployment.Endpoint] = service.NewAsyncAgentServiceClient(conn) + clientSet.agentMetadataClients[agentDeployment.Endpoint] = service.NewAgentMetadataServiceClient(conn) } return clientSet } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index 93f808e98c..4aa1405f0a 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -27,7 +27,7 @@ func TestInitializeClients(t *testing.T) { syncAgentClients: make(map[string]service.SyncAgentServiceClient), agentMetadataClients: make(map[string]service.AgentMetadataServiceClient), } - cs = initializeAgentClientSets(ctx) + cs = getAgentClientSets(ctx) _, ok := cs.syncAgentClients["y"] assert.True(t, ok) _, ok = cs.asyncAgentClients["x"] diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 8bc60b68fb..14e2a26db2 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -39,8 +39,8 @@ func GetAgentRegistry() Registry { func SetAgentRegistry(r Registry) { mu.Lock() + defer mu.Unlock() agentRegistry = r - mu.Unlock() } type Plugin struct { @@ -342,9 +342,8 @@ func (p Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (ser func (p Plugin) watchAgents(ctx context.Context) { go wait.Until(func() { - mu.Lock() - defer mu.Unlock() - updateAgentRegistry(ctx, p.cs) + clientSet := getAgentClientSets(ctx) + updateAgentRegistry(ctx, clientSet) }, p.cfg.PollInterval.Duration, ctx.Done()) } @@ -396,7 +395,7 @@ func newAgentPlugin() webapi.PluginEntry { ctx := context.Background() cfg := GetConfig() - clientSet := initializeAgentClientSets(ctx) + clientSet := getAgentClientSets(ctx) updateAgentRegistry(ctx, clientSet) supportedTaskTypes := append(maps.Keys(GetAgentRegistry()), cfg.SupportedTaskTypes...) From 4e9e7fb87c1903308517cc5752d792af2319a16b Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 8 Jun 2024 02:26:40 -0700 Subject: [PATCH 15/21] errorf Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 82fa68ca00..250b33f7d9 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -132,11 +132,11 @@ func updateAgentRegistry(ctx context.Context, cs *ClientSet) { } if !ok { - logger.Warningf(finalCtx, "failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment.Endpoint, err) + logger.Errorf(finalCtx, "failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment.Endpoint, err) continue } - logger.Warningf(finalCtx, "failed to list agent: [%v] with error: [%v]", agentDeployment.Endpoint, err) + logger.Errorf(finalCtx, "failed to list agent: [%v] with error: [%v]", agentDeployment.Endpoint, err) continue } From 8e3e80572a945b845126848aac5a502a40d10411 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 8 Jun 2024 02:35:33 -0700 Subject: [PATCH 16/21] update Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/client.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 250b33f7d9..c56b3d97b4 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -155,10 +155,11 @@ func updateAgentRegistry(ctx context.Context, cs *ClientSet) { } // If the agent doesn't implement the metadata service, we construct the registry based on the configuration for taskType, agentDeploymentID := range cfg.AgentForTaskTypes { - agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID] - if ok { + if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok { agent := &Agent{AgentDeployment: agentDeployment, IsSync: false} - agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent} + if _, ok := agentRegistry[taskType]; !ok { + agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent} + } } } } From a1d2a2c6c747696b1ec1f2fa88e9ef273f5b652d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 8 Jun 2024 02:40:29 -0700 Subject: [PATCH 17/21] remove logger Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/client.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index c56b3d97b4..71db78f2ae 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -168,8 +168,6 @@ func updateAgentRegistry(ctx context.Context, cs *ClientSet) { } func getAgentClientSets(ctx context.Context) *ClientSet { - logger.Infof(ctx, "Initializing agent clients") - clientSet := &ClientSet{ asyncAgentClients: make(map[string]service.AsyncAgentServiceClient), syncAgentClients: make(map[string]service.SyncAgentServiceClient), From ac0641561e47db53adc219852a4d9dee540f017d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 8 Jun 2024 02:44:39 -0700 Subject: [PATCH 18/21] nit Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 71db78f2ae..2ec301bebe 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -156,8 +156,8 @@ func updateAgentRegistry(ctx context.Context, cs *ClientSet) { // If the agent doesn't implement the metadata service, we construct the registry based on the configuration for taskType, agentDeploymentID := range cfg.AgentForTaskTypes { if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok { - agent := &Agent{AgentDeployment: agentDeployment, IsSync: false} if _, ok := agentRegistry[taskType]; !ok { + agent := &Agent{AgentDeployment: agentDeployment, IsSync: false} agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent} } } From 6fb31abadd6843c809f5d1a54962634b7f905360 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 8 Jun 2024 02:46:59 -0700 Subject: [PATCH 19/21] nit Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/client.go | 2 +- flyteplugins/go/tasks/plugins/webapi/agent/plugin.go | 8 ++++---- flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 2ec301bebe..d8c8b055dc 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -164,7 +164,7 @@ func updateAgentRegistry(ctx context.Context, cs *ClientSet) { } } logger.Debugf(ctx, "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry)) - SetAgentRegistry(agentRegistry) + setAgentRegistry(agentRegistry) } func getAgentClientSets(ctx context.Context) *ClientSet { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 14e2a26db2..368dffcef4 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -31,13 +31,13 @@ var ( mu sync.RWMutex ) -func GetAgentRegistry() Registry { +func getAgentRegistry() Registry { mu.Lock() defer mu.Unlock() return agentRegistry } -func SetAgentRegistry(r Registry) { +func setAgentRegistry(r Registry) { mu.Lock() defer mu.Unlock() agentRegistry = r @@ -370,7 +370,7 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *fly } func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) { - r := GetAgentRegistry() + r := getAgentRegistry() if agent, exists := r[taskCategory.Name][taskCategory.Version]; exists { return agent.AgentDeployment, agent.IsSync } @@ -397,7 +397,7 @@ func newAgentPlugin() webapi.PluginEntry { clientSet := getAgentClientSets(ctx) updateAgentRegistry(ctx, clientSet) - supportedTaskTypes := append(maps.Keys(GetAgentRegistry()), cfg.SupportedTaskTypes...) + supportedTaskTypes := append(maps.Keys(getAgentRegistry()), cfg.SupportedTaskTypes...) return webapi.PluginEntry{ ID: "agent-service", diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 149d74646d..19af85eed3 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -321,7 +321,7 @@ func TestInitializeAgentRegistry(t *testing.T) { updateAgentRegistry(context.Background(), cs) // In golang, the order of keys in a map is random. So, we sort the keys before asserting. - agentRegistryKeys := maps.Keys(GetAgentRegistry()) + agentRegistryKeys := maps.Keys(getAgentRegistry()) sort.Strings(agentRegistryKeys) assert.Equal(t, agentRegistryKeys, []string{"task1", "task2", "task3"}) From 26731cd20b1cb5e39be2836cb866a15b8bb546c3 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 8 Jun 2024 02:57:16 -0700 Subject: [PATCH 20/21] lint Signed-off-by: Kevin Su --- .../go/tasks/plugins/webapi/agent/client_test.go | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index 4aa1405f0a..7b0df4fdec 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -5,8 +5,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" ) func TestInitializeClients(t *testing.T) { @@ -22,12 +20,8 @@ func TestInitializeClients(t *testing.T) { ctx := context.Background() err := SetConfig(&cfg) assert.NoError(t, err) - cs := &ClientSet{ - asyncAgentClients: make(map[string]service.AsyncAgentServiceClient), - syncAgentClients: make(map[string]service.SyncAgentServiceClient), - agentMetadataClients: make(map[string]service.AgentMetadataServiceClient), - } - cs = getAgentClientSets(ctx) + cs : + = getAgentClientSets(ctx) _, ok := cs.syncAgentClients["y"] assert.True(t, ok) _, ok = cs.asyncAgentClients["x"] From e91de3bd0cc503d22e236bf57d724c314cbc30f5 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 8 Jun 2024 02:57:35 -0700 Subject: [PATCH 21/21] lint Signed-off-by: Kevin Su --- flyteplugins/go/tasks/plugins/webapi/agent/client_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index 7b0df4fdec..1850f2128f 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -20,8 +20,7 @@ func TestInitializeClients(t *testing.T) { ctx := context.Background() err := SetConfig(&cfg) assert.NoError(t, err) - cs : - = getAgentClientSets(ctx) + cs := getAgentClientSets(ctx) _, ok := cs.syncAgentClients["y"] assert.True(t, ok) _, ok = cs.asyncAgentClients["x"]