Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Agent ClientSet #4718

Merged
merged 13 commits into from
Jan 31, 2024
159 changes: 159 additions & 0 deletions flyteplugins/go/tasks/plugins/webapi/agent/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package agent

import (
"context"
"crypto/x509"
"fmt"

"golang.org/x/exp/maps"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/status"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
)

// ClientSet contains the clients exposed to communicate with various agent services.
// Both maps are keyed by the agent deployment's endpoint string and are
// populated together by initializeClients, one entry per dialed deployment.
type ClientSet struct {
// agentClients holds the async agent service client for each endpoint.
agentClients map[string]service.AsyncAgentServiceClient // map[endpoint] => client
// agentMetadataClients holds the metadata service client for each endpoint;
// initializeAgentRegistry uses it to call ListAgents.
agentMetadataClients map[string]service.AgentMetadataServiceClient // map[endpoint] => client
}

func getGrpcConnection(ctx context.Context, agent *Agent) (*grpc.ClientConn, error) {
var opts []grpc.DialOption

if agent.Insecure {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
} else {
pool, err := x509.SystemCertPool()
if err != nil {
return nil, err
}

Check warning on line 37 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L28-L37

Added lines #L28 - L37 were not covered by tests

creds := credentials.NewClientTLSFromCert(pool, "")
opts = append(opts, grpc.WithTransportCredentials(creds))

Check warning on line 40 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L39-L40

Added lines #L39 - L40 were not covered by tests
}

if len(agent.DefaultServiceConfig) != 0 {
opts = append(opts, grpc.WithDefaultServiceConfig(agent.DefaultServiceConfig))
}

Check warning on line 45 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L43-L45

Added lines #L43 - L45 were not covered by tests

var err error
conn, err := grpc.Dial(agent.Endpoint, opts...)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
if cerr := conn.Close(); cerr != nil {
grpclog.Infof("Failed to close conn to %s: %v", agent, cerr)
}
return

Check warning on line 57 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L47-L57

Added lines #L47 - L57 were not covered by tests
}
go func() {
<-ctx.Done()
if cerr := conn.Close(); cerr != nil {
grpclog.Infof("Failed to close conn to %s: %v", agent, cerr)
}

Check warning on line 63 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L59-L63

Added lines #L59 - L63 were not covered by tests
}()
}()

return conn, nil

Check warning on line 67 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L67

Added line #L67 was not covered by tests
}

// getFinalTimeout resolves the timeout to apply to the named operation for
// the given agent: the per-operation entry in agent.Timeouts when one is
// present, otherwise agent.DefaultTimeout.
func getFinalTimeout(operation string, agent *Agent) config.Duration {
	override, found := agent.Timeouts[operation]
	if !found {
		return agent.DefaultTimeout
	}

	return override
}

// getFinalContext derives the context to use for the named operation against
// the given agent. When the resolved timeout is non-zero the returned context
// carries that timeout; when it is zero, ctx is returned unchanged together
// with a no-op cancel function, so callers can always call cancel.
func getFinalContext(ctx context.Context, operation string, agent *Agent) (context.Context, context.CancelFunc) {
	if d := getFinalTimeout(operation, agent).Duration; d != 0 {
		return context.WithTimeout(ctx, d)
	}

	noop := func() {}
	return ctx, noop
}

func initializeAgentRegistry(cs *ClientSet) (map[string]*Agent, error) {
agentRegistry := make(map[string]*Agent)
cfg := GetConfig()
var agentDeployments []*Agent

// Ensure that the old configuration is backward compatible
for taskType, agentID := range cfg.AgentForTaskTypes {
agentRegistry[taskType] = cfg.Agents[agentID]
}

if len(cfg.DefaultAgent.Endpoint) != 0 {
agentDeployments = append(agentDeployments, &cfg.DefaultAgent)
}

Check warning on line 99 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L98-L99

Added lines #L98 - L99 were not covered by tests
agentDeployments = append(agentDeployments, maps.Values(cfg.Agents)...)
for _, agentDeployment := range agentDeployments {
client := cs.agentMetadataClients[agentDeployment.Endpoint]

finalCtx, cancel := getFinalContext(context.Background(), "ListAgents", agentDeployment)
defer cancel()

res, err := client.ListAgents(finalCtx, &admin.ListAgentsRequest{})
if err != nil {
grpcStatus, ok := status.FromError(err)
if grpcStatus.Code() == codes.Unimplemented {
// we should not panic here, as we want to continue to support old agent settings
logger.Infof(context.Background(), "list agent method not implemented for agent: [%v]", agentDeployment)
continue

Check warning on line 113 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L109-L113

Added lines #L109 - L113 were not covered by tests
}

if !ok {
return nil, fmt.Errorf("failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment, err)
}

Check warning on line 118 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L116-L118

Added lines #L116 - L118 were not covered by tests

return nil, fmt.Errorf("failed to list agent: [%v] with error: [%v]", agentDeployment, err)

Check warning on line 120 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L120

Added line #L120 was not covered by tests
}

agents := res.GetAgents()
for _, agent := range agents {
supportedTaskTypes := agent.SupportedTaskTypes
for _, supportedTaskType := range supportedTaskTypes {
agentRegistry[supportedTaskType] = agentDeployment
}
}
}

return agentRegistry, nil
}

func initializeClients(ctx context.Context) (*ClientSet, error) {
agentClients := make(map[string]service.AsyncAgentServiceClient)
agentMetadataClients := make(map[string]service.AgentMetadataServiceClient)

var agentDeployments []*Agent
cfg := GetConfig()

if len(cfg.DefaultAgent.Endpoint) != 0 {
agentDeployments = append(agentDeployments, &cfg.DefaultAgent)
}

Check warning on line 144 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L143-L144

Added lines #L143 - L144 were not covered by tests
agentDeployments = append(agentDeployments, maps.Values(cfg.Agents)...)
for _, agentDeployment := range agentDeployments {
conn, err := getGrpcConnection(ctx, agentDeployment)
if err != nil {
return nil, err
}
agentClients[agentDeployment.Endpoint] = service.NewAsyncAgentServiceClient(conn)
agentMetadataClients[agentDeployment.Endpoint] = service.NewAgentMetadataServiceClient(conn)

Check warning on line 152 in flyteplugins/go/tasks/plugins/webapi/agent/client.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/plugins/webapi/agent/client.go#L147-L152

Added lines #L147 - L152 were not covered by tests
}

return &ClientSet{
agentClients: agentClients,
agentMetadataClients: agentMetadataClients,
}, nil
}
18 changes: 18 additions & 0 deletions flyteplugins/go/tasks/plugins/webapi/agent/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package agent

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

// TestInitializeClients installs a copy of the default configuration and
// checks that initializeClients succeeds and yields a non-nil ClientSet.
func TestInitializeClients(t *testing.T) {
	cfg := defaultConfig
	assert.NoError(t, SetConfig(&cfg))

	clientSet, err := initializeClients(context.Background())
	assert.NoError(t, err)
	assert.NotNil(t, clientSet)
}
Loading
Loading