diff --git a/server/router/api/v1/user_service_shortcuts.go b/server/router/api/v1/user_service_shortcuts.go index 72d0a085bb143..775532f113f03 100644 --- a/server/router/api/v1/user_service_shortcuts.go +++ b/server/router/api/v1/user_service_shortcuts.go @@ -7,6 +7,7 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" + "github.com/pkg/errors" "github.com/usememos/memos/internal/util" "github.com/usememos/memos/plugin/filter" v1pb "github.com/usememos/memos/proto/gen/api/v1" @@ -78,10 +79,7 @@ func (s *APIV1Service) CreateShortcut(ctx context.Context, request *v1pb.CreateS if newShortcut.Title == "" { return nil, status.Errorf(codes.InvalidArgument, "title is required") } - if newShortcut.Filter == "" { - return nil, status.Errorf(codes.InvalidArgument, "filter is required") - } - if _, err := filter.Parse(newShortcut.Filter, filter.MemoFilterCELAttributes...); err != nil { + if err := s.validateFilter(ctx, newShortcut.Filter); err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err) } if request.ValidateOnly { @@ -171,11 +169,7 @@ func (s *APIV1Service) UpdateShortcut(ctx context.Context, request *v1pb.UpdateS } shortcut.Title = request.Shortcut.GetTitle() } else if field == "filter" { - if request.Shortcut.GetFilter() == "" { - return nil, status.Errorf(codes.InvalidArgument, "filter is required") - } - // Validate the filter. - if _, err := filter.Parse(request.Shortcut.GetFilter(), filter.MemoFilterCELAttributes...); err != nil { + if err := s.validateFilter(ctx, request.Shortcut.GetFilter()); err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err) } shortcut.Filter = request.Shortcut.GetFilter() @@ -244,3 +238,20 @@ func (s *APIV1Service) DeleteShortcut(ctx context.Context, request *v1pb.DeleteS return &emptypb.Empty{}, nil } + +func (s *APIV1Service) validateFilter(_ context.Context, filterStr string) error { + if filterStr == "" { + return errors.New("filter cannot be empty") + } + // Validate the filter. + parsedExpr, err := filter.Parse(filterStr, filter.MemoFilterCELAttributes...) + if err != nil { + return errors.Wrap(err, "failed to parse filter") + } + convertCtx := filter.NewConvertContext() + err = s.Store.GetDriver().ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + if err != nil { + return errors.Wrap(err, "failed to convert filter to SQL") + } + return nil +} diff --git a/store/db/mysql/memo.go b/store/db/mysql/memo.go index e519c26fc70bf..5b1b4667018ee 100644 --- a/store/db/mysql/memo.go +++ b/store/db/mysql/memo.go @@ -118,7 +118,7 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo } convertCtx := filter.NewConvertContext() // ConvertExprToSQL converts the parsed expression to a SQL condition string. - if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { return nil, err } where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String())) diff --git a/store/db/mysql/memo_filter.go b/store/db/mysql/memo_filter.go index 6622c0f5d1f50..2bf2ef1fc946c 100644 --- a/store/db/mysql/memo_filter.go +++ b/store/db/mysql/memo_filter.go @@ -12,7 +12,7 @@ import ( "github.com/usememos/memos/plugin/filter" ) -func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { +func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok { switch v.CallExpr.Function { case "_||_", "_&&_": @@ -22,7 +22,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { if _, err := ctx.Buffer.WriteString("("); err != nil { return err } - if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { return err } operator := "AND" @@ -32,7 +32,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { return err } - if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil { + if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil { return err } if _, err := ctx.Buffer.WriteString(")"); err != nil { @@ -45,7 +45,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { return err } - if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { return err } if _, err := ctx.Buffer.WriteString(")"); err != nil { diff --git a/store/db/mysql/memo_filter_test.go b/store/db/mysql/memo_filter_test.go index b5a6e71ef79c0..8162b3ce3c29e 100644 --- a/store/db/mysql/memo_filter_test.go +++ b/store/db/mysql/memo_filter_test.go @@ -52,10 +52,11 @@ func TestConvertExprToSQL(t *testing.T) { } for _, tt := range tests { + db := &DB{} parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) require.NoError(t, err) convertCtx := filter.NewConvertContext() - err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) require.NoError(t, err) require.Equal(t, tt.want, convertCtx.Buffer.String()) require.Equal(t, tt.args, convertCtx.Args) diff --git a/store/db/postgres/memo.go b/store/db/postgres/memo.go index f5e1560ccc375..242fe0409ac83 100644 --- a/store/db/postgres/memo.go +++ b/store/db/postgres/memo.go @@ -110,7 +110,7 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo convertCtx := filter.NewConvertContext() convertCtx.ArgsOffset = len(args) // ConvertExprToSQL converts the parsed expression to a SQL condition string. - if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { return nil, err } where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String())) diff --git a/store/db/postgres/memo_filter.go b/store/db/postgres/memo_filter.go index 9096dab79dd61..452e5761a8bcf 100644 --- a/store/db/postgres/memo_filter.go +++ b/store/db/postgres/memo_filter.go @@ -12,7 +12,7 @@ import ( "github.com/usememos/memos/plugin/filter" ) -func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { +func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok { switch v.CallExpr.Function { case "_||_", "_&&_": @@ -22,7 +22,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { if _, err := ctx.Buffer.WriteString("("); err != nil { return err } - if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { return err } operator := "AND" @@ -32,7 +32,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { return err } - if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil { + if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil { return err } if _, err := ctx.Buffer.WriteString(")"); err != nil { @@ -45,7 +45,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { return err } - if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { return err } if _, err := ctx.Buffer.WriteString(")"); err != nil { diff --git a/store/db/postgres/memo_filter_test.go b/store/db/postgres/memo_filter_test.go index c5deb5a1a9028..95fb626ba48a6 100644 --- a/store/db/postgres/memo_filter_test.go +++ b/store/db/postgres/memo_filter_test.go @@ -52,10 +52,11 @@ func TestRestoreExprToSQL(t *testing.T) { } for _, tt := range tests { + db := &DB{} parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) require.NoError(t, err) convertCtx := filter.NewConvertContext() - err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) require.NoError(t, err) require.Equal(t, tt.want, convertCtx.Buffer.String()) require.Equal(t, tt.args, convertCtx.Args) diff --git a/store/db/sqlite/memo.go b/store/db/sqlite/memo.go index 58db5c576c488..d1ea204adc10f 100644 --- a/store/db/sqlite/memo.go +++ b/store/db/sqlite/memo.go @@ -110,7 +110,7 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo } convertCtx := filter.NewConvertContext() // ConvertExprToSQL converts the parsed expression to a SQL condition string. - if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { + if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil { return nil, err } where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String())) diff --git a/store/db/sqlite/memo_filter.go b/store/db/sqlite/memo_filter.go index 496b759b45764..07d02a69a921b 100644 --- a/store/db/sqlite/memo_filter.go +++ b/store/db/sqlite/memo_filter.go @@ -12,7 +12,7 @@ import ( "github.com/usememos/memos/plugin/filter" ) -func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { +func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok { switch v.CallExpr.Function { case "_||_", "_&&_": @@ -22,7 +22,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { if _, err := ctx.Buffer.WriteString("("); err != nil { return err } - if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { return err } operator := "AND" @@ -32,7 +32,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil { return err } - if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil { + if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil { return err } if _, err := ctx.Buffer.WriteString(")"); err != nil { @@ -45,7 +45,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error { if _, err := ctx.Buffer.WriteString("NOT ("); err != nil { return err } - if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { + if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil { return err } if _, err := ctx.Buffer.WriteString(")"); err != nil { diff --git a/store/db/sqlite/memo_filter_test.go b/store/db/sqlite/memo_filter_test.go index ae5b44147200d..c78b5d0dd7b5d 100644 --- a/store/db/sqlite/memo_filter_test.go +++ b/store/db/sqlite/memo_filter_test.go @@ -57,10 +57,11 @@ func TestConvertExprToSQL(t *testing.T) { } for _, tt := range tests { + db := &DB{} parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...) require.NoError(t, err) convertCtx := filter.NewConvertContext() - err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) + err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()) require.NoError(t, err) require.Equal(t, tt.want, convertCtx.Buffer.String()) require.Equal(t, tt.args, convertCtx.Args) diff --git a/store/driver.go b/store/driver.go index 94bd5113e358f..603ab1d4242b5 100644 --- a/store/driver.go +++ b/store/driver.go @@ -3,6 +3,10 @@ package store import ( "context" "database/sql" + + exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/usememos/memos/plugin/filter" ) // Driver is an interface for store driver. @@ -73,4 +77,7 @@ type Driver interface { UpsertReaction(ctx context.Context, create *Reaction) (*Reaction, error) ListReactions(ctx context.Context, find *FindReaction) ([]*Reaction, error) DeleteReaction(ctx context.Context, delete *DeleteReaction) error + + // Shortcut related methods. + ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error } diff --git a/store/store.go b/store/store.go index 0bb3704d19dcc..030a0a22dca54 100644 --- a/store/store.go +++ b/store/store.go @@ -24,6 +24,10 @@ func New(driver Driver, profile *profile.Profile) *Store { } } +func (s *Store) GetDriver() Driver { + return s.driver +} + func (s *Store) Close() error { return s.driver.Close() }