diff --git a/cmd/gorse-in-one/main.go b/cmd/gorse-in-one/main.go index f9876bd99..52336bff6 100644 --- a/cmd/gorse-in-one/main.go +++ b/cmd/gorse-in-one/main.go @@ -72,7 +72,7 @@ var oneCommand = &cobra.Command{ conf.Database.CacheStore = "sqlite://cache.db" conf.Recommend.DataSource.PositiveFeedbackTypes = []string{"star", "like"} conf.Recommend.DataSource.ReadFeedbackTypes = []string{"read"} - if err := conf.Validate(true); err != nil { + if err := conf.Validate(); err != nil { log.Logger().Fatal("invalid config", zap.Error(err)) } diff --git a/config/config.go b/config/config.go index 9988aad02..d5db32424 100644 --- a/config/config.go +++ b/config/config.go @@ -25,6 +25,7 @@ import ( "sync" "time" + mapset "github.com/deckarep/golang-set/v2" "github.com/expr-lang/expr/parser" "github.com/go-playground/locales/en" ut "github.com/go-playground/universal-translator" @@ -591,7 +592,7 @@ func LoadConfig(path string, oneModel bool) (*Config, error) { } // validate config file - if err := conf.Validate(oneModel); err != nil { + if err := conf.Validate(); err != nil { return nil, errors.Trace(err) } @@ -605,7 +606,25 @@ func LoadConfig(path string, oneModel bool) (*Config, error) { return &conf, nil } -func (config *Config) Validate(oneModel bool) error { +func (config *Config) Validate() error { + // Check non-personalized recommenders + nonPersonalizedNames := mapset.NewSet[string]() + for _, nonPersonalized := range config.Recommend.NonPersonalized { + if nonPersonalizedNames.Contains(nonPersonalized.Name) { + return errors.Errorf("non-personalized recommender %v is duplicated", nonPersonalized.Name) + } + nonPersonalizedNames.Add(nonPersonalized.Name) + } + + // Check item-to-item recommenders + itemToItemNames := mapset.NewSet[string]() + for _, itemToItem := range config.Recommend.ItemToItem { + if itemToItemNames.Contains(itemToItem.Name) { + return errors.Errorf("item-to-item recommender %v is duplicated", itemToItem.Name) + } + itemToItemNames.Add(itemToItem.Name) + } + validate := validator.New() if err := validate.RegisterValidation("data_store", func(fl validator.FieldLevel) bool { prefixes := []string{ @@ -649,6 +668,10 @@ func (config *Config) Validate(oneModel bool) error { return errors.Trace(err) } if err := validate.RegisterValidation("item_expr", func(fl validator.FieldLevel) bool { + if fl.Field().String() == "" { + // Empty expression is legal. + return true + } _, err := parser.Parse(fl.Field().String()) return err == nil }); err != nil { diff --git a/config/config_test.go b/config/config_test.go index 3d474e642..414646f21 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -26,6 +26,7 @@ import ( "github.com/sclevine/yj/convert" "github.com/spf13/viper" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) func TestUnmarshal(t *testing.T) { @@ -443,3 +444,40 @@ func TestItemToItemConfig_Hash(t *testing.T) { b = ItemToItemConfig{Column: "b"} assert.NotEqual(t, a.Hash(), b.Hash()) } + +type ValidateTestSuite struct { + suite.Suite + *Config +} + +func (s *ValidateTestSuite) SetupTest() { + s.Config = GetDefaultConfig() + s.Database.CacheStore = "redis://localhost:6379/0" + s.Database.DataStore = "mysql://gorse:gorse_pass@tcp(localhost:3306)/gorse" +} + +func (s *ValidateTestSuite) TestDuplicateNonPersonalized() { + s.Recommend.NonPersonalized = []NonPersonalizedConfig{{ + Name: "most_starred_weekly", + Score: "count(feedback, .FeedbackType == 'star')", + }, { + Name: "most_starred_weekly", + Score: "count(feedback, .FeedbackType == 'star')", + }} + s.Error(s.Validate()) +} + +func (s *ValidateTestSuite) TestDuplicateItemToItem() { + s.Recommend.ItemToItem = []ItemToItemConfig{{ + Name: "item_to_item", + Type: "users", + }, { + Name: "item_to_item", + Type: "users", + }} + s.Error(s.Validate()) +} + +func TestValidate(t *testing.T) { + suite.Run(t, new(ValidateTestSuite)) +}