Skip to content

Commit

Permalink
push change
Browse files Browse the repository at this point in the history
Signed-off-by: Future-Outlier <[email protected]>
  • Loading branch information
Future-Outlier committed Jan 26, 2024
1 parent cd533da commit 80cf33f
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 27 deletions.
18 changes: 0 additions & 18 deletions flyteplugins/go/tasks/plugins/webapi/agent/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,10 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks"
)

func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient {
mockMetadataServiceClient := new(agentMocks.AgentMetadataServiceClient)
mockRequest := &admin.ListAgentsRequest{}
mockResponse := &admin.ListAgentsResponse{
Agents: []*admin.Agent{
{
Name: "test-agent",
SupportedTaskTypes: []string{"task1", "task2", "task3"},
},
},
}

mockMetadataServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil)
return mockMetadataServiceClient
}

func mockGetBadAsyncClientFunc() *agentMocks.AsyncAgentServiceClient {
return nil
}
Expand Down
46 changes: 42 additions & 4 deletions flyteplugins/go/tasks/plugins/webapi/agent/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks"
ioMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi"
agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks"
"github.com/flyteorg/flyte/flyteplugins/tests"
"github.com/flyteorg/flyte/flytestdlib/contextutils"
"github.com/flyteorg/flyte/flytestdlib/promutils"
Expand Down Expand Up @@ -160,9 +161,8 @@ func TestEndToEnd(t *testing.T) {
metricScope: iCtx.MetricsScope(),
cfg: GetConfig(),
cs: &ClientSet{
agentClients: map[string]service.AsyncAgentServiceClient{
"localhost:80": mockGetBadAsyncClientFunc(),
},
agentClients: map[string]service.AsyncAgentServiceClient{},
agentMetadataClients: map[string]service.AgentMetadataServiceClient{},
},
},
}, nil
Expand Down Expand Up @@ -296,6 +296,44 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext {

func newMockAgentPlugin() webapi.PluginEntry {

agentClient := new(agentMocks.AsyncAgentServiceClient)
// mockCreateRequest := &admin.CreateTaskRequest{}
// agentClient.On("CreateTask", mock.Anything, mockCreateRequest).Return(
// &admin.CreateTaskResponse{
// Res: &admin.CreateTaskResponse_ResourceMeta{
// ResourceMeta: []byte{1, 2, 3, 4},
// }}, nil)

agentClient.On("CreateTask", mock.Anything, mock.MatchedBy(func(req *admin.CreateTaskRequest) bool {
// Your custom logic to decide if the condition is met
expectedArgs := []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"}
return slices.Equal(req.Template.GetContainer().Args, expectedArgs)
})).Run(func(args mock.Arguments) {
req := args.Get(1).(*admin.CreateTaskRequest)
// Extract the mock.Call object
call := args.Get(0).(mock.Call)

if slices.Equal(req.Template.GetContainer().Args, []string{"pyflyte-fast-execute", "--output-prefix", "fake://bucket/prefix/nhv"}) {
// If condition is met, return a specific response
call.Return(&admin.CreateTaskResponse{
Res: &admin.CreateTaskResponse_ResourceMeta{
ResourceMeta: []byte{1, 2, 3, 4},
},
}, nil)
} else {
// Else, return a different response or error
call.Return(nil, fmt.Errorf("unexpected arguments"))
}
}).Maybe()

mockGetRequest := &admin.GetTaskRequest{}
agentClient.On("GetTask", mock.Anything, mockGetRequest).Return(
&admin.GetTaskResponse{Resource: &admin.Resource{State: admin.State_SUCCEEDED}}, nil)

mockDeleteRequest := &admin.DeleteTaskRequest{}
agentClient.On("DeleteTask", mock.Anything, mockDeleteRequest).Return(
&admin.DeleteTaskResponse{}, nil)

return webapi.PluginEntry{
ID: "agent-service",
SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "spark_job", "api_task"},
Expand All @@ -306,7 +344,7 @@ func newMockAgentPlugin() webapi.PluginEntry {
cfg: GetConfig(),
cs: &ClientSet{
agentClients: map[string]service.AsyncAgentServiceClient{
"": &MockAsyncTask{},
"": agentClient,
},
},
},
Expand Down
2 changes: 1 addition & 1 deletion flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR

client := p.cs.agentClients[agent.Endpoint]
if client == nil {
return nil, nil, fmt.Errorf("default agent:[%v] is not connected, please check if it is up and running", agent)
return nil, nil, fmt.Errorf("default agent is not connected, please check if endpoint:[%v] is up and running", agent.Endpoint)
}

finalCtx, cancel := getFinalContext(ctx, "CreateTask", agent)
Expand Down
24 changes: 20 additions & 4 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/exp/maps"

"github.com/flyteorg/flyte/flyteidl/clients/go/coreutils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
flyteIdlCore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
Expand All @@ -19,10 +15,14 @@ import (
pluginCoreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks"
ioMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks"
webapiPlugin "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi/mocks"
agentMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/webapi/agent/mocks"
"github.com/flyteorg/flyte/flyteplugins/tests"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/promutils"
"github.com/flyteorg/flyte/flytestdlib/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/exp/maps"
)

func TestSyncTask(t *testing.T) {
Expand Down Expand Up @@ -186,6 +186,22 @@ func TestPlugin(t *testing.T) {
})
}

func getMockMetadataServiceClient() *agentMocks.AgentMetadataServiceClient {
mockMetadataServiceClient := new(agentMocks.AgentMetadataServiceClient)
mockRequest := &admin.ListAgentsRequest{}
mockResponse := &admin.ListAgentsResponse{
Agents: []*admin.Agent{
{
Name: "test-agent",
SupportedTaskTypes: []string{"task1", "task2", "task3"},
},
},
}

mockMetadataServiceClient.On("ListAgents", mock.Anything, mockRequest).Return(mockResponse, nil)
return mockMetadataServiceClient
}

func TestInitializeAgentRegistry(t *testing.T) {
agentClients := make(map[string]service.AsyncAgentServiceClient)
agentMetadataClients := make(map[string]service.AgentMetadataServiceClient)
Expand Down

0 comments on commit 80cf33f

Please sign in to comment.