diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client.go b/flyteplugins/go/tasks/plugins/webapi/agent/client.go index 9f4409c8e6..d8c8b055dc 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client.go @@ -3,7 +3,6 @@ package agent import ( "context" "crypto/x509" - "fmt" "golang.org/x/exp/maps" "google.golang.org/grpc" @@ -98,8 +97,7 @@ func getFinalContext(ctx context.Context, operation string, agent *Deployment) ( return context.WithTimeout(ctx, timeout) } -func initializeAgentRegistry(cs *ClientSet) (Registry, error) { - logger.Infof(context.Background(), "Initializing agent registry") +func updateAgentRegistry(ctx context.Context, cs *ClientSet) { agentRegistry := make(Registry) cfg := GetConfig() var agentDeployments []*Deployment @@ -115,9 +113,13 @@ func initializeAgentRegistry(cs *ClientSet) (Registry, error) { } agentDeployments = append(agentDeployments, maps.Values(cfg.AgentDeployments)...) for _, agentDeployment := range agentDeployments { - client := cs.agentMetadataClients[agentDeployment.Endpoint] + client, ok := cs.agentMetadataClients[agentDeployment.Endpoint] + if !ok { + logger.Warningf(ctx, "Agent client not found in the clientSet for the endpoint: %v", agentDeployment.Endpoint) + continue + } - finalCtx, cancel := getFinalContext(context.Background(), "ListAgents", agentDeployment) + finalCtx, cancel := getFinalContext(ctx, "ListAgents", agentDeployment) defer cancel() res, err := client.ListAgents(finalCtx, &admin.ListAgentsRequest{}) @@ -125,15 +127,17 @@ func initializeAgentRegistry(cs *ClientSet) (Registry, error) { grpcStatus, ok := status.FromError(err) if grpcStatus.Code() == codes.Unimplemented { // we should not panic here, as we want to continue to support old agent settings - logger.Infof(context.Background(), "list agent method not implemented for agent: [%v]", agentDeployment) + logger.Warningf(finalCtx, "list agent method not implemented for agent: [%v]", agentDeployment.Endpoint) continue } if !ok { - return nil, fmt.Errorf("failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment, err) + logger.Errorf(finalCtx, "failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment.Endpoint, err) + continue } - return nil, fmt.Errorf("failed to list agent: [%v] with error: [%v]", agentDeployment, err) + logger.Errorf(finalCtx, "failed to list agent: [%v] with error: [%v]", agentDeployment.Endpoint, err) + continue } for _, agent := range res.GetAgents() { @@ -148,20 +152,27 @@ func initializeAgentRegistry(cs *ClientSet) (Registry, error) { agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync} agentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent} } - logger.Infof(context.Background(), "[%v] is a sync agent: [%v]", agent.Name, agent.IsSync) - logger.Infof(context.Background(), "[%v] supports task category: [%v]", agent.Name, supportedTaskCategories) + } + // If the agent doesn't implement the metadata service, we construct the registry based on the configuration + for taskType, agentDeploymentID := range cfg.AgentForTaskTypes { + if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok { + if _, ok := agentRegistry[taskType]; !ok { + agent := &Agent{AgentDeployment: agentDeployment, IsSync: false} + agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent} + } + } } } - - return agentRegistry, nil + logger.Debugf(ctx, "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry)) + setAgentRegistry(agentRegistry) } -func initializeClients(ctx context.Context) (*ClientSet, error) { - logger.Infof(ctx, "Initializing agent clients") - - asyncAgentClients := make(map[string]service.AsyncAgentServiceClient) - syncAgentClients := make(map[string]service.SyncAgentServiceClient) - agentMetadataClients := make(map[string]service.AgentMetadataServiceClient) +func getAgentClientSets(ctx context.Context) *ClientSet { + clientSet := &ClientSet{ + asyncAgentClients: make(map[string]service.AsyncAgentServiceClient), + syncAgentClients: make(map[string]service.SyncAgentServiceClient), + agentMetadataClients: make(map[string]service.AgentMetadataServiceClient), + } var agentDeployments []*Deployment cfg := GetConfig() @@ -170,19 +181,19 @@ func initializeClients(ctx context.Context) (*ClientSet, error) { agentDeployments = append(agentDeployments, &cfg.DefaultAgent) } agentDeployments = append(agentDeployments, maps.Values(cfg.AgentDeployments)...) - for _, agentService := range agentDeployments { - conn, err := getGrpcConnection(ctx, agentService) + for _, agentDeployment := range agentDeployments { + if _, ok := clientSet.agentMetadataClients[agentDeployment.Endpoint]; ok { + logger.Infof(ctx, "Agent client already initialized for [%v]", agentDeployment.Endpoint) + continue + } + conn, err := getGrpcConnection(ctx, agentDeployment) if err != nil { - return nil, err + logger.Errorf(ctx, "failed to create connection to agent: [%v] with error: [%v]", agentDeployment, err) + continue } - syncAgentClients[agentService.Endpoint] = service.NewSyncAgentServiceClient(conn) - asyncAgentClients[agentService.Endpoint] = service.NewAsyncAgentServiceClient(conn) - agentMetadataClients[agentService.Endpoint] = service.NewAgentMetadataServiceClient(conn) + clientSet.syncAgentClients[agentDeployment.Endpoint] = service.NewSyncAgentServiceClient(conn) + clientSet.asyncAgentClients[agentDeployment.Endpoint] = service.NewAsyncAgentServiceClient(conn) + clientSet.agentMetadataClients[agentDeployment.Endpoint] = service.NewAgentMetadataServiceClient(conn) } - - return &ClientSet{ - syncAgentClients: syncAgentClients, - asyncAgentClients: asyncAgentClients, - agentMetadataClients: agentMetadataClients, - }, nil + return clientSet } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index 4ad7f8cbaa..1850f2128f 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -20,9 +20,7 @@ func TestInitializeClients(t *testing.T) { ctx := context.Background() err := SetConfig(&cfg) assert.NoError(t, err) - cs, err := initializeClients(ctx) - assert.NoError(t, err) - assert.NotNil(t, cs) + cs := getAgentClientSets(ctx) _, ok := cs.syncAgentClients["y"] assert.True(t, ok) _, ok = cs.asyncAgentClients["x"] diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/config.go b/flyteplugins/go/tasks/plugins/webapi/agent/config.go index 3f9fd354b6..f26499f320 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/config.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/config.go @@ -47,6 +47,7 @@ var ( // AsyncPlugin should be registered to at least one task type. // Reference: https://github.com/flyteorg/flyte/blob/master/flyteplugins/go/tasks/pluginmachinery/registry.go#L27 SupportedTaskTypes: []string{"task_type_1", "task_type_2"}, + PollInterval: config.Duration{Duration: 10 * time.Second}, } configSection = pluginsConfig.MustRegisterSubSection("agent-service", &defaultConfig) @@ -71,6 +72,9 @@ type Config struct { // SupportedTaskTypes is a list of task types that are supported by this plugin. SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."` + + // PollInterval is the interval at which the plugin should poll the agent for metadata updates + PollInterval config.Duration `json:"pollInterval" pflag:",The interval at which the plugin should poll the agent for metadata updates."` } type Deployment struct { diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index 6fb3828c0c..f3e626524c 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -35,6 +35,10 @@ import ( ) func TestEndToEnd(t *testing.T) { + agentRegistry = Registry{ + "openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}}, + "spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}}, + } iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error { return nil } @@ -118,6 +122,7 @@ func TestEndToEnd(t *testing.T) { cfg: GetConfig(), cs: &ClientSet{ asyncAgentClients: map[string]service.AsyncAgentServiceClient{}, + syncAgentClients: map[string]service.SyncAgentServiceClient{}, agentMetadataClients: map[string]service.AgentMetadataServiceClient{}, }, }, nil @@ -326,7 +331,6 @@ func newMockSyncAgentPlugin() webapi.PluginEntry { defaultAgentEndpoint: syncAgentClient, }, }, - agentRegistry: Registry{"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}}}, }, nil }, } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index e41c4ccaa0..368dffcef4 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -4,9 +4,11 @@ import ( "context" "encoding/gob" "fmt" + "sync" "time" "golang.org/x/exp/maps" + "k8s.io/apimachinery/pkg/util/wait" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" @@ -24,11 +26,27 @@ import ( type Registry map[string]map[int32]*Agent // map[taskTypeName][taskTypeVersion] => Agent -type Plugin struct { - metricScope promutils.Scope - cfg *Config - cs *ClientSet +var ( agentRegistry Registry + mu sync.RWMutex +) + +func getAgentRegistry() Registry { + mu.Lock() + defer mu.Unlock() + return agentRegistry +} + +func setAgentRegistry(r Registry) { + mu.Lock() + defer mu.Unlock() + agentRegistry = r +} + +type Plugin struct { + metricScope promutils.Scope + cfg *Config + cs *ClientSet } type ResourceWrapper struct { @@ -95,7 +113,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String() taskCategory := admin.TaskCategory{Name: taskTemplate.Type, Version: taskTemplate.TaskTypeVersion} - agent, isSync := getFinalAgent(&taskCategory, p.cfg, p.agentRegistry) + agent, isSync := getFinalAgent(&taskCategory, p.cfg) taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()) @@ -193,7 +211,7 @@ func (p Plugin) ExecuteTaskSync( func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg, p.agentRegistry) + agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg) client, err := p.getAsyncAgentClient(ctx, agent) if err != nil { @@ -226,7 +244,7 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error return nil } metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) - agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg, p.agentRegistry) + agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg) client, err := p.getAsyncAgentClient(ctx, agent) if err != nil { @@ -322,6 +340,13 @@ func (p Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (ser return client, nil } +func (p Plugin) watchAgents(ctx context.Context) { + go wait.Until(func() { + clientSet := getAgentClientSets(ctx) + updateAgentRegistry(ctx, clientSet) + }, p.cfg.PollInterval.Duration, ctx.Done()) +} + func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *flyteIdl.LiteralMap) error { taskTemplate, err := taskCtx.TaskReader().Read(ctx) if err != nil { @@ -344,11 +369,11 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *fly return taskCtx.OutputWriter().Put(ctx, opReader) } -func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config, agentRegistry Registry) (*Deployment, bool) { - if agent, exists := agentRegistry[taskCategory.Name][taskCategory.Version]; exists { +func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) { + r := getAgentRegistry() + if agent, exists := r[taskCategory.Name][taskCategory.Version]; exists { return agent.AgentDeployment, agent.IsSync } - return &cfg.DefaultAgent, false } @@ -367,31 +392,24 @@ func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata } func newAgentPlugin() webapi.PluginEntry { - cs, err := initializeClients(context.Background()) - if err != nil { - // We should wait for all agents to be up and running before starting the server - panic(fmt.Sprintf("failed to initialize clients with error: %v", err)) - } - - agentRegistry, err := initializeAgentRegistry(cs) - if err != nil { - panic(fmt.Sprintf("failed to initialize agent registry with error: %v", err)) - } - + ctx := context.Background() cfg := GetConfig() - supportedTaskTypes := append(maps.Keys(agentRegistry), cfg.SupportedTaskTypes...) - logger.Infof(context.Background(), "AgentDeployment service supports task types: %v", supportedTaskTypes) + + clientSet := getAgentClientSets(ctx) + updateAgentRegistry(ctx, clientSet) + supportedTaskTypes := append(maps.Keys(getAgentRegistry()), cfg.SupportedTaskTypes...) return webapi.PluginEntry{ ID: "agent-service", SupportedTaskTypes: supportedTaskTypes, PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { - return &Plugin{ - metricScope: iCtx.MetricsScope(), - cfg: cfg, - cs: cs, - agentRegistry: agentRegistry, - }, nil + plugin := &Plugin{ + metricScope: iCtx.MetricsScope(), + cfg: cfg, + cs: clientSet, + } + plugin.watchAgents(ctx) + return plugin, nil }, } } @@ -399,6 +417,5 @@ func newAgentPlugin() webapi.PluginEntry { func RegisterAgentPlugin() { gob.Register(ResourceMetaWrapper{}) gob.Register(ResourceWrapper{}) - pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin()) } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 3e8cb882c8..19af85eed3 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -60,15 +60,15 @@ func TestPlugin(t *testing.T) { t.Run("test getFinalAgent", func(t *testing.T) { agent := &Agent{AgentDeployment: &Deployment{Endpoint: "localhost:80"}} - agentRegistry := Registry{"spark": {defaultTaskTypeVersion: agent}} + agentRegistry = Registry{"spark": {defaultTaskTypeVersion: agent}} spark := &admin.TaskCategory{Name: "spark", Version: defaultTaskTypeVersion} foo := &admin.TaskCategory{Name: "foo", Version: defaultTaskTypeVersion} bar := &admin.TaskCategory{Name: "bar", Version: defaultTaskTypeVersion} - agentDeployment, _ := getFinalAgent(spark, &cfg, agentRegistry) + agentDeployment, _ := getFinalAgent(spark, &cfg) assert.Equal(t, agentDeployment.Endpoint, "localhost:80") - agentDeployment, _ = getFinalAgent(foo, &cfg, agentRegistry) + agentDeployment, _ = getFinalAgent(foo, &cfg) assert.Equal(t, agentDeployment.Endpoint, cfg.DefaultAgent.Endpoint) - agentDeployment, _ = getFinalAgent(bar, &cfg, agentRegistry) + agentDeployment, _ = getFinalAgent(bar, &cfg) assert.Equal(t, agentDeployment.Endpoint, cfg.DefaultAgent.Endpoint) }) @@ -318,11 +318,10 @@ func TestInitializeAgentRegistry(t *testing.T) { cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"} err := SetConfig(&cfg) assert.NoError(t, err) - agentRegistry, err := initializeAgentRegistry(cs) - assert.NoError(t, err) + updateAgentRegistry(context.Background(), cs) // In golang, the order of keys in a map is random. So, we sort the keys before asserting. - agentRegistryKeys := maps.Keys(agentRegistry) + agentRegistryKeys := maps.Keys(getAgentRegistry()) sort.Strings(agentRegistryKeys) assert.Equal(t, agentRegistryKeys, []string{"task1", "task2", "task3"})