Skip to content

Commit

Permalink
feat: validate shortcut's filter
Browse files Browse the repository at this point in the history
  • Loading branch information
johnnyjoygh committed Feb 3, 2025
1 parent 8f35086 commit a7ca634
Show file tree
Hide file tree
Showing 12 changed files with 52 additions and 27 deletions.
29 changes: 20 additions & 9 deletions server/router/api/v1/user_service_shortcuts.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"

"github.com/pkg/errors"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/filter"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
Expand Down Expand Up @@ -78,10 +79,7 @@ func (s *APIV1Service) CreateShortcut(ctx context.Context, request *v1pb.CreateS
if newShortcut.Title == "" {
return nil, status.Errorf(codes.InvalidArgument, "title is required")
}
if newShortcut.Filter == "" {
return nil, status.Errorf(codes.InvalidArgument, "filter is required")
}
if _, err := filter.Parse(newShortcut.Filter, filter.MemoFilterCELAttributes...); err != nil {
if err := s.validateFilter(ctx, newShortcut.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
if request.ValidateOnly {
Expand Down Expand Up @@ -171,11 +169,7 @@ func (s *APIV1Service) UpdateShortcut(ctx context.Context, request *v1pb.UpdateS
}
shortcut.Title = request.Shortcut.GetTitle()
} else if field == "filter" {
if request.Shortcut.GetFilter() == "" {
return nil, status.Errorf(codes.InvalidArgument, "filter is required")
}
// Validate the filter.
if _, err := filter.Parse(request.Shortcut.GetFilter(), filter.MemoFilterCELAttributes...); err != nil {
if err := s.validateFilter(ctx, request.Shortcut.GetFilter()); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
shortcut.Filter = request.Shortcut.GetFilter()
Expand Down Expand Up @@ -244,3 +238,20 @@ func (s *APIV1Service) DeleteShortcut(ctx context.Context, request *v1pb.DeleteS

return &emptypb.Empty{}, nil
}

func (s *APIV1Service) validateFilter(_ context.Context, filterStr string) error {
if filterStr == "" {
return errors.New("filter cannot be empty")
}
// Validate the filter.
parsedExpr, err := filter.Parse(filterStr, filter.MemoFilterCELAttributes...)
if err != nil {
return errors.Wrap(err, "failed to parse filter")
}
convertCtx := filter.NewConvertContext()
err = s.Store.GetDriver().ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
if err != nil {
return errors.Wrap(err, "failed to convert filter to SQL")
}
return nil
}
2 changes: 1 addition & 1 deletion store/db/mysql/memo.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
Expand Down
8 changes: 4 additions & 4 deletions store/db/mysql/memo_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/usememos/memos/plugin/filter"
)

func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
switch v.CallExpr.Function {
case "_||_", "_&&_":
Expand All @@ -22,7 +22,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
if _, err := ctx.Buffer.WriteString("("); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
return err
}
operator := "AND"
Expand All @@ -32,7 +32,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
Expand All @@ -45,7 +45,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
Expand Down
3 changes: 2 additions & 1 deletion store/db/mysql/memo_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@ func TestConvertExprToSQL(t *testing.T) {
}

for _, tt := range tests {
db := &DB{}
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
Expand Down
2 changes: 1 addition & 1 deletion store/db/postgres/memo.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
convertCtx := filter.NewConvertContext()
convertCtx.ArgsOffset = len(args)
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
Expand Down
8 changes: 4 additions & 4 deletions store/db/postgres/memo_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/usememos/memos/plugin/filter"
)

func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
switch v.CallExpr.Function {
case "_||_", "_&&_":
Expand All @@ -22,7 +22,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
if _, err := ctx.Buffer.WriteString("("); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
return err
}
operator := "AND"
Expand All @@ -32,7 +32,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
Expand All @@ -45,7 +45,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
Expand Down
3 changes: 2 additions & 1 deletion store/db/postgres/memo_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@ func TestRestoreExprToSQL(t *testing.T) {
}

for _, tt := range tests {
db := &DB{}
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
Expand Down
2 changes: 1 addition & 1 deletion store/db/sqlite/memo.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
if err := ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
if err := d.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
where = append(where, fmt.Sprintf("(%s)", convertCtx.Buffer.String()))
Expand Down
8 changes: 4 additions & 4 deletions store/db/sqlite/memo_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/usememos/memos/plugin/filter"
)

func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
switch v.CallExpr.Function {
case "_||_", "_&&_":
Expand All @@ -22,7 +22,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
if _, err := ctx.Buffer.WriteString("("); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
return err
}
operator := "AND"
Expand All @@ -32,7 +32,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[1]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
Expand All @@ -45,7 +45,7 @@ func ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error {
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
return err
}
if err := ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
if err := d.ConvertExprToSQL(ctx, v.CallExpr.Args[0]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
Expand Down
3 changes: 2 additions & 1 deletion store/db/sqlite/memo_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@ func TestConvertExprToSQL(t *testing.T) {
}

for _, tt := range tests {
db := &DB{}
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
err = ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
err = db.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
Expand Down
7 changes: 7 additions & 0 deletions store/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ package store
import (
"context"
"database/sql"

exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"

"github.com/usememos/memos/plugin/filter"
)

// Driver is an interface for store driver.
Expand Down Expand Up @@ -73,4 +77,7 @@ type Driver interface {
UpsertReaction(ctx context.Context, create *Reaction) (*Reaction, error)
ListReactions(ctx context.Context, find *FindReaction) ([]*Reaction, error)
DeleteReaction(ctx context.Context, delete *DeleteReaction) error

// Shortcut related methods.
ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) error
}
4 changes: 4 additions & 0 deletions store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ func New(driver Driver, profile *profile.Profile) *Store {
}
}

func (s *Store) GetDriver() Driver {
return s.driver
}

func (s *Store) Close() error {
return s.driver.Close()
}

0 comments on commit a7ca634

Please sign in to comment.