diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go index a75afb157d..235c3fe7e3 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/client_test.go @@ -5,28 +5,10 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" ) -func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { - mockMetadataServiceClient := new(agentMocks.AgentMetadataServiceClient) - mockRequest := &admin.ListAgentsRequest{} - mockResponse := &admin.ListAgentsResponse{ - Agents: []*admin.Agent{ - { - Name: "test-agent", - SupportedTaskTypes: []string{"task1", "task2", "task3"}, - }, - }, - } - - mockMetadataServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) - return mockMetadataServiceClient -} - func mockGetBadAsyncClientFunc() *agentMocks.AsyncAgentServiceClient { return nil } diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go index 5943b36c78..d476fa2119 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go @@ -25,6 +25,7 @@ import ( pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" ioMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi" + agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" "github.com/flyteorg/flyte/flyteplugins/tests" "github.com/flyteorg/flyte/flytestdlib/contextutils" "github.com/flyteorg/flyte/flytestdlib/promutils" @@ -160,9 +161,8 @@ func TestEndToEnd(t *testing.T) { metricScope: iCtx.MetricsScope(), cfg: GetConfig(), cs: &ClientSet{ - agentClients: map[string]service.AsyncAgentServiceClient{ - "localhost:80": mockGetBadAsyncClientFunc(), - }, + agentClients: map[string]service.AsyncAgentServiceClient{}, + agentMetadataClients: map[string]service.AgentMetadataServiceClient{}, }, }, }, nil @@ -296,6 +296,44 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext { func newMockAgentPlugin() webapi.PluginEntry { + agentClient := new(agentMocks.AsyncAgentServiceClient) + // mockCreateRequest := &admin.CreateTaskRequest{} + // agentClient.On("CreateTask", mock.Anything, mockCreateRequest).Return( + // &admin.CreateTaskResponse{ + // Res: &admin.CreateTaskResponse_ResourceMeta{ + // ResourceMeta: []byte{1, 2, 3, 4}, + // }}, nil) + + agentClient.On("CreateTask", mock.Anything, mock.MatchedBy(func(req *admin.CreateTaskRequest) bool { + // Your custom logic to decide if the condition is met + expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"} + return slices.Equal(req.Template.GetContainer().Args, expectedArgs) + })).Run(func(args mock.Arguments) { + req := args.Get(1).(*admin.CreateTaskRequest) + // Extract the mock.Call object + call := args.Get(0).(mock.Call) + + if slices.Equal(req.Template.GetContainer().Args, []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"}) { + // If condition is met, return a specific response + call.Return(&admin.CreateTaskResponse{ + Res: &admin.CreateTaskResponse_ResourceMeta{ + ResourceMeta: []byte{1, 2, 3, 4}, + }, + }, nil) + } else { + // Else, return a different response or error + call.Return(nil, fmt.Errorf("unexpected arguments")) + } + }).Maybe() + + mockGetRequest := &admin.GetTaskRequest{} + agentClient.On("GetTask", mock.Anything, mockGetRequest).Return( + &admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil) + + mockDeleteRequest := &admin.DeleteTaskRequest{} + agentClient.On("DeleteTask", mock.Anything, mockDeleteRequest).Return( + &admin.DeleteTaskResponse{}, nil) + return webapi.PluginEntry{ ID: "agent-service", SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark_job", "api_task"}, @@ -306,7 +344,7 @@ func newMockAgentPlugin() webapi.PluginEntry { cfg: GetConfig(), cs: &ClientSet{ agentClients: map[string]service.AsyncAgentServiceClient{ - "": &MockAsyncTask{}, + "": agentClient, }, }, }, diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go index 5b2f0d4f51..8f5ce243af 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin.go @@ -85,7 +85,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR client := p.cs.agentClients[agent.Endpoint] if client == nil { - return nil, nil, fmt.Errorf("default agent:[%v] is not connected, please check if it is up and running", agent) + return nil, nil, fmt.Errorf("default agent is not connected, please check if endpoint:[%v] is up and running", agent.Endpoint) } finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent) diff --git a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go index 2ccd56a24b..d17adf5180 100644 --- a/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go +++ b/flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go @@ -6,10 +6,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "golang.org/x/exp/maps" - "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin" flyteIdlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" @@ -19,10 +15,14 @@ import ( pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" ioMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" webapiPlugin "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi/mocks" + agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks" "github.com/flyteorg/flyte/flyteplugins/tests" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/promutils" "github.com/flyteorg/flyte/flytestdlib/storage" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/exp/maps" ) func TestSyncTask(t *testing.T) { @@ -186,6 +186,22 @@ func TestPlugin(t *testing.T) { }) } +func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient { + mockMetadataServiceClient := new(agentMocks.AgentMetadataServiceClient) + mockRequest := &admin.ListAgentsRequest{} + mockResponse := &admin.ListAgentsResponse{ + Agents: []*admin.Agent{ + { + Name: "test-agent", + SupportedTaskTypes: []string{"task1", "task2", "task3"}, + }, + }, + } + + mockMetadataServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil) + return mockMetadataServiceClient +} + func TestInitializeAgentRegistry(t *testing.T) { agentClients := make(map[string]service.AsyncAgentServiceClient) agentMetadataClients := make(map[string]service.AgentMetadataServiceClient)