diff --git a/workflow/controller/operator.go b/workflow/controller/operator.go index d18ce328ab9a..96cab8df89a3 100644 --- a/workflow/controller/operator.go +++ b/workflow/controller/operator.go @@ -3285,44 +3285,66 @@ func (woc *wfOperationCtx) processAggregateNodeOutputs(scope *wfScope, prefix st // Adding per-output aggregated value placeholders for outputName, valueList := range outputParamValueLists { key = fmt.Sprintf("%s.outputs.parameters.%s", prefix, outputName) - unmarshalSuccess := true - var unmarshalledList []interface{} - for _, value := range valueList { - // Only try to unmarshal things that look like json lists or dicts - // and especially avoid unstringified numbers which are valid JSON - valueTrim := strings.Trim(value, " ") - valueTrimLen := len(valueTrim) - if valueTrimLen > 0 && - !((valueTrim[0] == '{' && valueTrim[valueTrimLen-1] == '}') || - (valueTrim[0] == '[' && valueTrim[valueTrimLen-1] == ']')) { - unmarshalSuccess = false - break // This isn't a json list or dict, leave it - } - var unmarshalledValue interface{} - err := json.Unmarshal([]byte(value), &unmarshalledValue) - if err != nil { - unmarshalSuccess = false - break // Unmarshal failed, fall back to strings - } - unmarshalledList = append(unmarshalledList, unmarshalledValue) - } - var valueListJSON []byte - if unmarshalSuccess { - valueListJSON, err = json.Marshal(unmarshalledList) - if err != nil { - return err - } - } else { - valueListJSON, err = json.Marshal(valueList) - if err != nil { - return err - } + valueListJson, err := aggregatedJsonValueList(valueList) + if err != nil { + return err } - scope.addParamToScope(key, string(valueListJSON)) + scope.addParamToScope(key, valueListJson) } return nil } +// tryJsonUnmarshal unmarshals each item in the list assuming it is +// JSON and NOT a plain JSON value. +// If returns success only if all items can be unmarshalled and are either +// maps or lists +func tryJsonUnmarshal(valueList []string) ([]interface{}, bool) { + success := true + var list []interface{} + for _, value := range valueList { + var unmarshalledValue interface{} + err := json.Unmarshal([]byte(value), &unmarshalledValue) + if err != nil { + success = false + break // Unmarshal failed, fall back to strings + } + switch unmarshalledValue.(type) { + case []interface{}: + case map[string]interface{}: + // Keep these types + default: + // Drop anything else + success = false + } + if !success { + break + } + list = append(list, unmarshalledValue) + } + return list, success +} + +// aggregatedJsonValueList returns a string containing a JSON list, holding +// all of the values from the valueList. +// It tries to understand what's wanted from inner JSON using tryJsonUnmarshall +func aggregatedJsonValueList(valueList []string) (string, error) { + unmarshalledList, success := tryJsonUnmarshal(valueList) + var valueListJSON []byte + var err error + if success { + valueListJSON, err = json.Marshal(unmarshalledList) + if err != nil { + return "", err + } + } else { + valueListJSON, err = json.Marshal(valueList) + if err != nil { + return "", err + } + } + return string(valueListJSON), nil +} + // addParamToGlobalScope exports any desired node outputs to the global scope, and adds it to the global outputs. func (woc *wfOperationCtx) addParamToGlobalScope(param wfv1.Parameter) { if param.GlobalName == "" { diff --git a/workflow/controller/operator_aggregation_test.go b/workflow/controller/operator_aggregation_test.go new file mode 100644 index 000000000000..b966c448a6d3 --- /dev/null +++ b/workflow/controller/operator_aggregation_test.go @@ -0,0 +1,64 @@ +package controller + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTryJsonUnmarshal(t *testing.T) { + for _, testcase := range []struct { + input []string + success bool + expected []interface{} + }{ + {[]string{"1"}, false, nil}, + {[]string{"1", "2"}, false, nil}, + {[]string{"foo"}, false, nil}, + {[]string{"foo", "bar"}, false, nil}, + {[]string{`["1"]`, "2"}, false, nil}, // Fails on second element + {[]string{`{"foo":"1"}`, "2"}, false, nil}, // Fails on second element + {[]string{`["1"]`, `["2"]`}, true, []interface{}{[]interface{}{"1"}, []interface{}{"2"}}}, + {[]string{`["1"]`, `["2"]`}, true, []interface{}{[]interface{}{"1"}, []interface{}{"2"}}}, + {[]string{"\n[\"1\"] \n", "\t[\"2\"]\t"}, true, []interface{}{[]interface{}{"1"}, []interface{}{"2"}}}, + {[]string{`{"number":"1"}`, `{"number":"2"}`}, true, []interface{}{map[string]interface{}{"number": "1"}, map[string]interface{}{"number": "2"}}}, + {[]string{`[{"foo":"apple", "bar":"pear"}]`, `{"foo":"banana"}`}, true, []interface{}{[]interface{}{map[string]interface{}{"bar": "pear", "foo": "apple"}}, map[string]interface{}{"foo": "banana"}}}, + } { + t.Run(fmt.Sprintf("Unmarshal %v", testcase.input), + func(t *testing.T) { + list, success := tryJsonUnmarshal(testcase.input) + require.Equal(t, testcase.success, success) + if success { + assert.Equal(t, testcase.expected, list) + } + }) + } +} + +func TestAggregatedJsonValueList(t *testing.T) { + for _, testcase := range []struct { + input []string + expected string + }{ + {[]string{"1"}, `["1"]`}, + {[]string{"1", "2"}, `["1","2"]`}, + {[]string{"foo"}, `["foo"]`}, + {[]string{"foo", "bar"}, `["foo","bar"]`}, + {[]string{`["1"]`, "2"}, `["[\"1\"]","2"]`}, // This is expected, but not really useful + {[]string{`{"foo":"1"}`, "2"}, `["{\"foo\":\"1\"}","2"]`}, // This is expected, but not really useful + {[]string{`["1"]`, `["2"]`}, `[["1"],["2"]]`}, + {[]string{` ["1"]`, `["2"] `}, `[["1"],["2"]]`}, + {[]string{"\n[\"1\"] \n", "\t[\"2\"]\t"}, `[["1"],["2"]]`}, + {[]string{`{"number":"1"}`, `{"number":"2"}`}, `[{"number":"1"},{"number":"2"}]`}, + {[]string{`[{"foo":"apple", "bar":"pear"}]`}, `[[{"bar":"pear","foo":"apple"}]]`}, // Sorted map keys here may make this a fragile test, can be dropped + } { + t.Run(fmt.Sprintf("Aggregate %v", testcase.input), + func(t *testing.T) { + result, err := aggregatedJsonValueList(testcase.input) + require.NoError(t, err) + assert.Equal(t, testcase.expected, result) + }) + } +}