From 93c2fb862aa950f2349664a817611c9ede779d1b Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Mon, 7 Oct 2024 11:06:59 -0400 Subject: [PATCH] Fix a bug where schema validation looses type information if the input has an any in it (#827) --- pkg/workflows/sdk/compute_generated.go | 70 +++++++++++++++------- pkg/workflows/sdk/gen/compute.go.tmpl | 7 ++- pkg/workflows/sdk/testutils/runner_test.go | 35 +++++++++++ 3 files changed, 90 insertions(+), 22 deletions(-) diff --git a/pkg/workflows/sdk/compute_generated.go b/pkg/workflows/sdk/compute_generated.go index 5772957b9..2b165573f 100644 --- a/pkg/workflows/sdk/compute_generated.go +++ b/pkg/workflows/sdk/compute_generated.go @@ -43,12 +43,15 @@ func Compute1[I0 any, O any](w *WorkflowSpecFactory, ref string, input Compute1I return capabilities.CapabilityResponse{}, err } - // verify against any schema + // verify against any schema by marshalling and unmarshalling ji, err := json.Marshal(inputs) if err != nil { return capabilities.CapabilityResponse{}, err } - if err := json.Unmarshal(ji, &inputs); err != nil { + + // use a temp variable to unmarshal to avoid type loss if the inputs has an any in it + var tmp runtime1Inputs[I0] + if err := json.Unmarshal(ji, &tmp); err != nil { return capabilities.CapabilityResponse{}, err } @@ -110,12 +113,15 @@ func Compute2[I0 any, I1 any, O any](w *WorkflowSpecFactory, ref string, input C return capabilities.CapabilityResponse{}, err } - // verify against any schema + // verify against any schema by marshalling and unmarshalling ji, err := json.Marshal(inputs) if err != nil { return capabilities.CapabilityResponse{}, err } - if err := json.Unmarshal(ji, &inputs); err != nil { + + // use a temp variable to unmarshal to avoid type loss if the inputs has an any in it + var tmp runtime2Inputs[I0, I1] + if err := json.Unmarshal(ji, &tmp); err != nil { return capabilities.CapabilityResponse{}, err } @@ -180,12 +186,15 @@ func Compute3[I0 any, I1 any, I2 any, O any](w *WorkflowSpecFactory, ref string, return capabilities.CapabilityResponse{}, err } - // verify against any schema + // verify against any schema by marshalling and unmarshalling ji, err := json.Marshal(inputs) if err != nil { return capabilities.CapabilityResponse{}, err } - if err := json.Unmarshal(ji, &inputs); err != nil { + + // use a temp variable to unmarshal to avoid type loss if the inputs has an any in it + var tmp runtime3Inputs[I0, I1, I2] + if err := json.Unmarshal(ji, &tmp); err != nil { return capabilities.CapabilityResponse{}, err } @@ -253,12 +262,15 @@ func Compute4[I0 any, I1 any, I2 any, I3 any, O any](w *WorkflowSpecFactory, ref return capabilities.CapabilityResponse{}, err } - // verify against any schema + // verify against any schema by marshalling and unmarshalling ji, err := json.Marshal(inputs) if err != nil { return capabilities.CapabilityResponse{}, err } - if err := json.Unmarshal(ji, &inputs); err != nil { + + // use a temp variable to unmarshal to avoid type loss if the inputs has an any in it + var tmp runtime4Inputs[I0, I1, I2, I3] + if err := json.Unmarshal(ji, &tmp); err != nil { return capabilities.CapabilityResponse{}, err } @@ -329,12 +341,15 @@ func Compute5[I0 any, I1 any, I2 any, I3 any, I4 any, O any](w *WorkflowSpecFact return capabilities.CapabilityResponse{}, err } - // verify against any schema + // verify against any schema by marshalling and unmarshalling ji, err := json.Marshal(inputs) if err != nil { return capabilities.CapabilityResponse{}, err } - if err := json.Unmarshal(ji, &inputs); err != nil { + + // use a temp variable to unmarshal to avoid type loss if the inputs has an any in it + var tmp runtime5Inputs[I0, I1, I2, I3, I4] + if err := json.Unmarshal(ji, &tmp); err != nil { return capabilities.CapabilityResponse{}, err } @@ -408,12 +423,15 @@ func Compute6[I0 any, I1 any, I2 any, I3 any, I4 any, I5 any, O any](w *Workflow return capabilities.CapabilityResponse{}, err } - // verify against any schema + // verify against any schema by marshalling and unmarshalling ji, err := json.Marshal(inputs) if err != nil { return capabilities.CapabilityResponse{}, err } - if err := json.Unmarshal(ji, &inputs); err != nil { + + // use a temp variable to unmarshal to avoid type loss if the inputs has an any in it + var tmp runtime6Inputs[I0, I1, I2, I3, I4, I5] + if err := json.Unmarshal(ji, &tmp); err != nil { return capabilities.CapabilityResponse{}, err } @@ -490,12 +508,15 @@ func Compute7[I0 any, I1 any, I2 any, I3 any, I4 any, I5 any, I6 any, O any](w * return capabilities.CapabilityResponse{}, err } - // verify against any schema + // verify against any schema by marshalling and unmarshalling ji, err := json.Marshal(inputs) if err != nil { return capabilities.CapabilityResponse{}, err } - if err := json.Unmarshal(ji, &inputs); err != nil { + + // use a temp variable to unmarshal to avoid type loss if the inputs has an any in it + var tmp runtime7Inputs[I0, I1, I2, I3, I4, I5, I6] + if err := json.Unmarshal(ji, &tmp); err != nil { return capabilities.CapabilityResponse{}, err } @@ -575,12 +596,15 @@ func Compute8[I0 any, I1 any, I2 any, I3 any, I4 any, I5 any, I6 any, I7 any, O return capabilities.CapabilityResponse{}, err } - // verify against any schema + // verify against any schema by marshalling and unmarshalling ji, err := json.Marshal(inputs) if err != nil { return capabilities.CapabilityResponse{}, err } - if err := json.Unmarshal(ji, &inputs); err != nil { + + // use a temp variable to unmarshal to avoid type loss if the inputs has an any in it + var tmp runtime8Inputs[I0, I1, I2, I3, I4, I5, I6, I7] + if err := json.Unmarshal(ji, &tmp); err != nil { return capabilities.CapabilityResponse{}, err } @@ -663,12 +687,15 @@ func Compute9[I0 any, I1 any, I2 any, I3 any, I4 any, I5 any, I6 any, I7 any, I8 return capabilities.CapabilityResponse{}, err } - // verify against any schema + // verify against any schema by marshalling and unmarshalling ji, err := json.Marshal(inputs) if err != nil { return capabilities.CapabilityResponse{}, err } - if err := json.Unmarshal(ji, &inputs); err != nil { + + // use a temp variable to unmarshal to avoid type loss if the inputs has an any in it + var tmp runtime9Inputs[I0, I1, I2, I3, I4, I5, I6, I7, I8] + if err := json.Unmarshal(ji, &tmp); err != nil { return capabilities.CapabilityResponse{}, err } @@ -754,12 +781,15 @@ func Compute10[I0 any, I1 any, I2 any, I3 any, I4 any, I5 any, I6 any, I7 any, I return capabilities.CapabilityResponse{}, err } - // verify against any schema + // verify against any schema by marshalling and unmarshalling ji, err := json.Marshal(inputs) if err != nil { return capabilities.CapabilityResponse{}, err } - if err := json.Unmarshal(ji, &inputs); err != nil { + + // use a temp variable to unmarshal to avoid type loss if the inputs has an any in it + var tmp runtime10Inputs[I0, I1, I2, I3, I4, I5, I6, I7, I8, I9] + if err := json.Unmarshal(ji, &tmp); err != nil { return capabilities.CapabilityResponse{}, err } diff --git a/pkg/workflows/sdk/gen/compute.go.tmpl b/pkg/workflows/sdk/gen/compute.go.tmpl index e2f7a22d3..c944ca2fa 100644 --- a/pkg/workflows/sdk/gen/compute.go.tmpl +++ b/pkg/workflows/sdk/gen/compute.go.tmpl @@ -48,12 +48,15 @@ func Compute{{.}}[{{range RangeNum .}}I{{.}} any, {{ end }}O any](w *WorkflowSp return capabilities.CapabilityResponse{}, err } - // verify against any schema + // verify against any schema by marshalling and unmarshalling ji, err := json.Marshal(inputs) if err != nil { return capabilities.CapabilityResponse{}, err } - if err := json.Unmarshal(ji, &inputs); err != nil { + + // use a temp variable to unmarshal to avoid type loss if the inputs has an any in it + var tmp runtime{{.}}Inputs[{{range RangeNum . }}I{{.}},{{ end }}] + if err := json.Unmarshal(ji, &tmp); err != nil { return capabilities.CapabilityResponse{}, err } diff --git a/pkg/workflows/sdk/testutils/runner_test.go b/pkg/workflows/sdk/testutils/runner_test.go index 27d7a7609..61f7a99a9 100644 --- a/pkg/workflows/sdk/testutils/runner_test.go +++ b/pkg/workflows/sdk/testutils/runner_test.go @@ -3,7 +3,9 @@ package testutils_test import ( "context" "errors" + "fmt" "reflect" + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -253,6 +255,39 @@ func TestRunner(t *testing.T) { }) } +func TestCompute(t *testing.T) { + t.Run("Inputs don't loose integer types when any is deserialized to", func(t *testing.T) { + workflow := sdk.NewWorkflowSpecFactory(sdk.NewWorkflowParams{Name: "name", Owner: "owner"}) + trigger := basictrigger.TriggerConfig{Name: "foo", Number: 100}.New(workflow) + toMap := sdk.Compute1(workflow, "tomap", sdk.Compute1Inputs[string]{Arg0: trigger.CoolOutput()}, func(runtime sdk.Runtime, i0 string) (map[string]any, error) { + v, err := strconv.Atoi(i0) + if err != nil { + return nil, err + } + + return map[string]any{"a": int64(v)}, nil + }) + + sdk.Compute1(workflow, "compute", sdk.Compute1Inputs[map[string]any]{Arg0: toMap.Value()}, func(runtime sdk.Runtime, input map[string]any) (any, error) { + actual := input["a"] + if int64(100) != actual { + return nil, fmt.Errorf("expected uint64(100), got %v of type %T", actual, actual) + } + + return actual, nil + }) + + runner := testutils.NewRunner(tests.Context(t)) + basictriggertest.Trigger(runner, func() (basictrigger.TriggerOutputs, error) { + return basictrigger.TriggerOutputs{CoolOutput: "100"}, nil + }) + + runner.Run(workflow) + + require.NoError(t, runner.Err()) + }) +} + func registrationWorkflow() (*sdk.WorkflowSpecFactory, map[string]any, map[string]any) { workflow := sdk.NewWorkflowSpecFactory(sdk.NewWorkflowParams{Name: "tester", Owner: "ryan"}) testTriggerConfig := map[string]any{"something": "from nothing"}