Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Watch agent metadata service #5017

Merged
merged 22 commits into from
Jun 8, 2024
34 changes: 20 additions & 14 deletions flyteplugins/go/tasks/plugins/webapi/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import (
"context"
"crypto/x509"
"fmt"

"golang.org/x/exp/maps"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -98,7 +96,7 @@
return context.WithTimeout(ctx, timeout)
}

func initializeAgentRegistry(cs *ClientSet) (Registry, error) {
func createAgentRegistry(ctx context.Context, cs *ClientSet) Registry {
agentRegistry := make(Registry)
cfg := GetConfig()
var agentDeployments []*Deployment
Expand All @@ -116,23 +114,25 @@
for _, agentDeployment := range agentDeployments {
client := cs.agentMetadataClients[agentDeployment.Endpoint]

finalCtx, cancel := getFinalContext(context.Background(), "ListAgents", agentDeployment)
finalCtx, cancel := getFinalContext(ctx, "ListAgents", agentDeployment)
defer cancel()

res, err := client.ListAgents(finalCtx, &admin.ListAgentsRequest{})
if err != nil {
grpcStatus, ok := status.FromError(err)
if grpcStatus.Code() == codes.Unimplemented {
// we should not panic here, as we want to continue to support old agent settings
logger.Infof(context.Background(), "list agent method not implemented for agent: [%v]", agentDeployment)
logger.Warningf(finalCtx, "list agent method not implemented for agent: [%v]", agentDeployment.Endpoint)

Check warning on line 125 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L125

Added line #L125 was not covered by tests
continue
}

if !ok {
return nil, fmt.Errorf("failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment, err)
logger.Warningf(finalCtx, "failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment.Endpoint, err)
continue

Check warning on line 131 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L130-L131

Added lines #L130 - L131 were not covered by tests
}

return nil, fmt.Errorf("failed to list agent: [%v] with error: [%v]", agentDeployment, err)
logger.Warningf(finalCtx, "failed to list agent: [%v] with error: [%v]", agentDeployment.Endpoint, err)
continue

Check warning on line 135 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L134-L135

Added lines #L134 - L135 were not covered by tests
}

for _, agent := range res.GetAgents() {
Expand All @@ -147,15 +147,21 @@
agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync}
agentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent}
}
logger.Infof(context.Background(), "[%v] is a sync agent: [%v]", agent.Name, agent.IsSync)
logger.Infof(context.Background(), "[%v] supports task category: [%v]", agent.Name, supportedTaskCategories)
}
// If the agent doesn't implement the metadata service, we construct the registry based on the configuration
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]
if ok {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: false}
agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
}

Check warning on line 157 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L155-L157

Added lines #L155 - L157 were not covered by tests
}
}

return agentRegistry, nil
logger.Debugf(context.Background(), "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry))
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
return agentRegistry
}

func initializeClients(ctx context.Context) (*ClientSet, error) {
func createAgentClientSets(ctx context.Context) *ClientSet {
asyncAgentClients := make(map[string]service.AsyncAgentServiceClient)
syncAgentClients := make(map[string]service.SyncAgentServiceClient)
agentMetadataClients := make(map[string]service.AgentMetadataServiceClient)
Expand All @@ -170,7 +176,7 @@
for _, agentService := range agentDeployments {
conn, err := getGrpcConnection(ctx, agentService)
if err != nil {
return nil, err
logger.Warningf(ctx, "failed to create connection to agent: [%v] with error: [%v]", agentService, err)

Check warning on line 179 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L179

Added line #L179 was not covered by tests
}
syncAgentClients[agentService.Endpoint] = service.NewSyncAgentServiceClient(conn)
asyncAgentClients[agentService.Endpoint] = service.NewAsyncAgentServiceClient(conn)
Expand All @@ -181,5 +187,5 @@
syncAgentClients: syncAgentClients,
asyncAgentClients: asyncAgentClients,
agentMetadataClients: agentMetadataClients,
}, nil
}
}
3 changes: 1 addition & 2 deletions flyteplugins/go/tasks/plugins/webapi/agent/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ func TestInitializeClients(t *testing.T) {
ctx := context.Background()
err := SetConfig(&cfg)
assert.NoError(t, err)
cs, err := initializeClients(ctx)
assert.NoError(t, err)
cs := createAgentClientSets(ctx)
assert.NotNil(t, cs)
_, ok := cs.syncAgentClients["y"]
assert.True(t, ok)
Expand Down
4 changes: 4 additions & 0 deletions flyteplugins/go/tasks/plugins/webapi/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ var (
// AsyncPlugin should be registered to at least one task type.
// Reference: https://github.com/flyteorg/flyte/blob/master/flyteplugins/go/tasks/pluginmachinery/registry.go#L27
SupportedTaskTypes: []string{"task_type_1", "task_type_2"},
PollInterval: config.Duration{Duration: 10 * time.Second},
}

configSection = pluginsConfig.MustRegisterSubSection("agent-service", &defaultConfig)
Expand All @@ -71,6 +72,9 @@ type Config struct {

// SupportedTaskTypes is a list of task types that are supported by this plugin.
SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."`

// PollInterval is the interval at which the plugin should poll the agent for metadata updates
PollInterval config.Duration `json:"pollInterval" pflag:",The interval at which the plugin should poll the agent for metadata updates."`
}

type Deployment struct {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ import (
)

func TestEndToEnd(t *testing.T) {
agentRegistry = Registry{
"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}},
"spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}},
}
iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error {
return nil
}
Expand Down Expand Up @@ -118,6 +122,7 @@ func TestEndToEnd(t *testing.T) {
cfg: GetConfig(),
cs: &ClientSet{
asyncAgentClients: map[string]service.AsyncAgentServiceClient{},
syncAgentClients: map[string]service.SyncAgentServiceClient{},
agentMetadataClients: map[string]service.AgentMetadataServiceClient{},
},
}, nil
Expand Down Expand Up @@ -323,7 +328,6 @@ func newMockSyncAgentPlugin() webapi.PluginEntry {
defaultAgentEndpoint: syncAgentClient,
},
},
agentRegistry: Registry{"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}}},
}, nil
},
}
Expand Down
57 changes: 28 additions & 29 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
"encoding/gob"
"fmt"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery"
"k8s.io/apimachinery/pkg/util/wait"
"time"

"golang.org/x/exp/maps"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
pluginErrors "github.com/flyteorg/flyte/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/template"
flyteIO "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io"
Expand All @@ -23,12 +24,12 @@
)

type Registry map[string]map[int32]*Agent // map[taskTypeName][taskTypeVersion] => Agent
var agentRegistry Registry

type Plugin struct {
metricScope promutils.Scope
cfg *Config
cs *ClientSet
agentRegistry Registry
metricScope promutils.Scope
cfg *Config
cs *ClientSet
}

type ResourceWrapper struct {
Expand Down Expand Up @@ -90,7 +91,7 @@
outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String()

taskCategory := admin.TaskCategory{Name: taskTemplate.Type, Version: taskTemplate.TaskTypeVersion}
agent, isSync := getFinalAgent(&taskCategory, p.cfg, p.agentRegistry)
agent, isSync := getFinalAgent(&taskCategory, p.cfg)

finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent)
defer cancel()
Expand Down Expand Up @@ -186,7 +187,7 @@

func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg, p.agentRegistry)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg)

client, err := p.getAsyncAgentClient(ctx, agent)
if err != nil {
Expand Down Expand Up @@ -219,7 +220,7 @@
return nil
}
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg, p.agentRegistry)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg)

Check warning on line 223 in flyteplugins/go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/plugin.go#L223

Added line #L223 was not covered by tests

client, err := p.getAsyncAgentClient(ctx, agent)
if err != nil {
Expand Down Expand Up @@ -315,6 +316,13 @@
return client, nil
}

func (p Plugin) watchAgents(ctx context.Context) {
go wait.Until(func() {
cs := createAgentClientSets(ctx)
agentRegistry = createAgentRegistry(ctx, cs)
}, p.cfg.PollInterval.Duration, ctx.Done())

Check warning on line 323 in flyteplugins/go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/plugin.go#L319-L323

Added lines #L319 - L323 were not covered by tests
}

func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *flyteIdl.LiteralMap) error {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
Expand All @@ -337,11 +345,10 @@
return taskCtx.OutputWriter().Put(ctx, opReader)
}

func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config, agentRegistry Registry) (*Deployment, bool) {
func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) {
if agent, exists := agentRegistry[taskCategory.Name][taskCategory.Version]; exists {
return agent.AgentDeployment, agent.IsSync
}

return &cfg.DefaultAgent, false
}

Expand All @@ -358,38 +365,30 @@
}

func newAgentPlugin() webapi.PluginEntry {
cs, err := initializeClients(context.Background())
if err != nil {
// We should wait for all agents to be up and running before starting the server
panic(fmt.Sprintf("failed to initialize clients with error: %v", err))
}

agentRegistry, err := initializeAgentRegistry(cs)
if err != nil {
panic(fmt.Sprintf("failed to initialize agent registry with error: %v", err))
}

ctx := context.Background()

Check warning on line 368 in flyteplugins/go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/plugin.go#L368

Added line #L368 was not covered by tests
cfg := GetConfig()

cs := createAgentClientSets(ctx)
agentRegistry := createAgentRegistry(ctx, cs)

Check warning on line 372 in flyteplugins/go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/plugin.go#L370-L372

Added lines #L370 - L372 were not covered by tests
supportedTaskTypes := append(maps.Keys(agentRegistry), cfg.SupportedTaskTypes...)
logger.Infof(context.Background(), "AgentDeployment service supports task types: %v", supportedTaskTypes)

return webapi.PluginEntry{
ID: "agent-service",
SupportedTaskTypes: supportedTaskTypes,
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: cfg,
cs: cs,
agentRegistry: agentRegistry,
}, nil
plugin := &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: cfg,
cs: cs,
}
plugin.watchAgents(ctx)
return plugin, nil

Check warning on line 385 in flyteplugins/go/tasks/plugins/webapi/agent/plugin.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/plugin.go#L379-L385

Added lines #L379 - L385 were not covered by tests
},
}
}

func RegisterAgentPlugin() {
gob.Register(ResourceMetaWrapper{})
gob.Register(ResourceWrapper{})

pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin())
}
13 changes: 6 additions & 7 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ func TestPlugin(t *testing.T) {

t.Run("test getFinalAgent", func(t *testing.T) {
agent := &Agent{AgentDeployment: &Deployment{Endpoint: "localhost:80"}}
agentRegistry := Registry{"spark": {defaultTaskTypeVersion: agent}}
agentRegistry = Registry{"spark": {defaultTaskTypeVersion: agent}}
spark := &admin.TaskCategory{Name: "spark", Version: defaultTaskTypeVersion}
foo := &admin.TaskCategory{Name: "foo", Version: defaultTaskTypeVersion}
bar := &admin.TaskCategory{Name: "bar", Version: defaultTaskTypeVersion}
agentDeployment, _ := getFinalAgent(spark, &cfg, agentRegistry)
agentDeployment, _ := getFinalAgent(spark, &cfg)
assert.Equal(t, agentDeployment.Endpoint, "localhost:80")
agentDeployment, _ = getFinalAgent(foo, &cfg, agentRegistry)
agentDeployment, _ = getFinalAgent(foo, &cfg)
assert.Equal(t, agentDeployment.Endpoint, cfg.DefaultAgent.Endpoint)
agentDeployment, _ = getFinalAgent(bar, &cfg, agentRegistry)
agentDeployment, _ = getFinalAgent(bar, &cfg)
assert.Equal(t, agentDeployment.Endpoint, cfg.DefaultAgent.Endpoint)
})

Expand Down Expand Up @@ -307,11 +307,10 @@ func TestInitializeAgentRegistry(t *testing.T) {
cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"}
err := SetConfig(&cfg)
assert.NoError(t, err)
agentRegistry, err := initializeAgentRegistry(cs)
assert.NoError(t, err)
registry := createAgentRegistry(context.Background(), cs)

// In golang, the order of keys in a map is random. So, we sort the keys before asserting.
agentRegistryKeys := maps.Keys(agentRegistry)
agentRegistryKeys := maps.Keys(registry)
sort.Strings(agentRegistryKeys)

assert.Equal(t, agentRegistryKeys, []string{"task1", "task2", "task3"})
Expand Down
Loading