diff --git a/plugin/filter/filter.go b/plugin/filter/filter.go index 145e8c9ecb90d..0a075b48be294 100644 --- a/plugin/filter/filter.go +++ b/plugin/filter/filter.go @@ -14,8 +14,8 @@ var MemoFilterCELAttributes = []cel.EnvOption{ // Parse parses the filter string and returns the parsed expression. // The filter string should be a CEL expression. -func Parse(filter string) (expr *exprv1.ParsedExpr, err error) { - e, err := cel.NewEnv(MemoFilterCELAttributes...) +func Parse(filter string, opts ...cel.EnvOption) (expr *exprv1.ParsedExpr, err error) { + e, err := cel.NewEnv(opts...) if err != nil { return nil, errors.Wrap(err, "failed to create CEL environment") } @@ -28,20 +28,23 @@ func Parse(filter string) (expr *exprv1.ParsedExpr, err error) { // GetConstValue returns the constant value of the expression. func GetConstValue(expr *exprv1.Expr) (any, error) { - switch v := expr.ExprKind.(type) { - case *exprv1.Expr_ConstExpr: - switch v.ConstExpr.ConstantKind.(type) { - case *exprv1.Constant_StringValue: - return v.ConstExpr.GetStringValue(), nil - case *exprv1.Constant_Int64Value: - return v.ConstExpr.GetInt64Value(), nil - case *exprv1.Constant_Uint64Value: - return v.ConstExpr.GetUint64Value(), nil - case *exprv1.Constant_DoubleValue: - return v.ConstExpr.GetDoubleValue(), nil - case *exprv1.Constant_BoolValue: - return v.ConstExpr.GetBoolValue(), nil - } + v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr) + if !ok { + return nil, errors.New("invalid constant expression") + } + + switch v.ConstExpr.ConstantKind.(type) { + case *exprv1.Constant_StringValue: + return v.ConstExpr.GetStringValue(), nil + case *exprv1.Constant_Int64Value: + return v.ConstExpr.GetInt64Value(), nil + case *exprv1.Constant_Uint64Value: + return v.ConstExpr.GetUint64Value(), nil + case *exprv1.Constant_DoubleValue: + return v.ConstExpr.GetDoubleValue(), nil + case *exprv1.Constant_BoolValue: + return v.ConstExpr.GetBoolValue(), nil + default: + return nil, errors.New("unexpected constant type") } - return nil, errors.New("invalid constant expression") } diff --git a/store/db/sqlite/memo_filter_test.go b/store/db/sqlite/memo_filter_test.go index 617b192c8a133..d5491ac9bceb7 100644 --- a/store/db/sqlite/memo_filter_test.go +++ b/store/db/sqlite/memo_filter_test.go @@ -31,7 +31,7 @@ func TestRestoreExprToSQL(t *testing.T) { } for _, tt := range tests { - parsedExpr, err := filter.Parse(tt.filter) + parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) require.NoError(t, err) result, err := RestoreExprToSQL(parsedExpr.GetExpr()) require.NoError(t, err)