diff --git a/cmd/dagu.go b/cmd/dagu.go index c9a169ebd..e43dcd21c 100644 --- a/cmd/dagu.go +++ b/cmd/dagu.go @@ -16,10 +16,17 @@ import ( ) var ( - version = "0.0.0" - stdin io.ReadCloser - sigs chan os.Signal - globalConfig *admin.Config + version = "0.0.0" + stdin io.ReadCloser + sigs chan os.Signal + globalFlags = []cli.Flag{ + &cli.StringFlag{ + Name: "config", + Usage: "Admin config", + Value: "", + Required: false, + }, + } ) func main() { @@ -30,10 +37,31 @@ func main() { } } -func loadDAG(dagPath, params string) (cfg *config.Config, err error) { - cl := &config.Loader{BaseConfig: globalConfig.BaseConfig} - cfg, err = cl.Load(dagPath, params) - return +func loadGlobalConfig(c *cli.Context) (cfg *admin.Config, err error) { + l := &admin.Loader{} + cfgFile := c.String("config") + if cfgFile == "" { + cfgFile = settings.MustGet(settings.SETTING__ADMIN_CONFIG) + } + cfg, err = l.LoadAdminConfig(cfgFile) + if err == admin.ErrConfigNotFound { + cfg = admin.DefaultConfig() + err = nil + } + if err != nil { + return nil, fmt.Errorf("loading admin config failed: %w", err) + } + return cfg, err +} + +func loadDAG(c *cli.Context, dagPath, params string) (dag *config.Config, err error) { + cfg, err := loadGlobalConfig(c) + if err != nil { + return nil, err + } + cl := &config.Loader{BaseConfig: cfg.BaseConfig} + dag, err = cl.Load(dagPath, params) + return dag, err } func listenSignals(abortFunc func(sig os.Signal)) { @@ -72,18 +100,5 @@ func makeApp() *cli.App { newSchedulerCommand(), newVersionCommand(), }, - Before: func(c *cli.Context) error { - l := &admin.Loader{} - cfg, err := l.LoadAdminConfig(settings.MustGet(settings.SETTING__ADMIN_CONFIG)) - if err == admin.ErrConfigNotFound { - cfg = admin.DefaultConfig() - err = nil - } - if err != nil { - return fmt.Errorf("loading admin config failed: %w", err) - } - globalConfig = cfg - return nil - }, } } diff --git a/cmd/dagu_test.go b/cmd/dagu_test.go index 776a77f04..e73bd318d 100644 --- a/cmd/dagu_test.go +++ b/cmd/dagu_test.go @@ -19,6 +19,7 @@ type appTest struct { args []string errored bool output []string + errMessage []string exactOutput string stdin io.ReadCloser } @@ -77,6 +78,12 @@ func runAppTestOutput(app *cli.App, test appTest, t *testing.T) { return } + if err != nil && len(test.errMessage) > 0 { + for _, v := range test.errMessage { + require.Contains(t, err.Error(), v) + } + } + var buf bytes.Buffer _, err = io.Copy(&buf, r) require.NoError(t, err) diff --git a/cmd/dry.go b/cmd/dry.go index 41762ab87..6398bcdb1 100644 --- a/cmd/dry.go +++ b/cmd/dry.go @@ -12,16 +12,17 @@ func newDryCommand() *cli.Command { return &cli.Command{ Name: "dry", Usage: "dagu dry [--params=\"\"] ", - Flags: []cli.Flag{ + Flags: append( + globalFlags, &cli.StringFlag{ Name: "params", Usage: "parameters", Value: "", Required: false, }, - }, + ), Action: func(c *cli.Context) error { - cfg, err := loadDAG(c.Args().Get(0), c.String("params")) + cfg, err := loadDAG(c, c.Args().Get(0), c.String("params")) if err != nil { return err } diff --git a/cmd/retry.go b/cmd/retry.go index 792f73337..59299b5a3 100644 --- a/cmd/retry.go +++ b/cmd/retry.go @@ -5,7 +5,9 @@ import ( "path/filepath" "github.com/yohamta/dagu" + "github.com/yohamta/dagu/internal/config" "github.com/yohamta/dagu/internal/database" + "github.com/yohamta/dagu/internal/models" "github.com/urfave/cli/v2" ) @@ -13,36 +15,31 @@ import ( func newRetryCommand() *cli.Command { return &cli.Command{ Name: "retry", - Usage: "dagu retry --req= ", - Flags: []cli.Flag{ + Usage: "dagu retry --req= ", + Flags: append( + globalFlags, &cli.StringFlag{ Name: "req", Usage: "request-id", Value: "", Required: true, }, - }, + ), Action: func(c *cli.Context) error { f, _ := filepath.Abs(c.Args().Get(0)) + db := database.Database{Config: database.DefaultConfig()} requestId := c.String("req") - return retry(f, requestId) + status, err := db.FindByRequestId(f, requestId) + if err != nil { + return err + } + cfg, err := loadDAG(c, c.Args().Get(0), status.Status.Params) + return retry(cfg, status) }, } } -func retry(f, requestId string) error { - db := database.Database{ - Config: database.DefaultConfig(), - } - status, err := db.FindByRequestId(f, requestId) - if err != nil { - return err - } - cfg, err := loadDAG(f, status.Status.Params) - if err != nil { - return err - } - +func retry(cfg *config.Config, status *models.StatusFile) error { a := &dagu.Agent{ AgentConfig: &dagu.AgentConfig{ DAG: cfg, diff --git a/cmd/retry_test.go b/cmd/retry_test.go index 7e4558499..d33675c71 100644 --- a/cmd/retry_test.go +++ b/cmd/retry_test.go @@ -64,6 +64,10 @@ func Test_retryCommand(t *testing.T) { } func Test_retryFail(t *testing.T) { - configPath := testConfig("cmd_retry.yaml") - require.Error(t, retry(configPath, "invalid-request-id")) + app := makeApp() + runAppTestOutput(app, appTest{ + args: []string{"", "retry", fmt.Sprintf("--req=%s", + "invalid-request-id"), testConfig("cmd_retry.yaml")}, errored: true, + errMessage: []string{"request id not found"}, + }, t) } diff --git a/cmd/scheduler.go b/cmd/scheduler.go index 2f406b27e..296d3eef0 100644 --- a/cmd/scheduler.go +++ b/cmd/scheduler.go @@ -12,20 +12,25 @@ func newSchedulerCommand() *cli.Command { return &cli.Command{ Name: "scheduler", Usage: "dagu scheduler", - Flags: []cli.Flag{ + Flags: append( + globalFlags, &cli.StringFlag{ Name: "dags", Usage: "DAGs directory", Value: "", Required: false, }, - }, + ), Action: func(c *cli.Context) error { + cfg, err := loadGlobalConfig(c) + if err != nil { + return err + } dagsDir := c.String("dags") if dagsDir != "" { - globalConfig.DAGs = dagsDir + cfg.DAGs = dagsDir } - return startScheduler(globalConfig) + return startScheduler(cfg) }, } } diff --git a/cmd/server.go b/cmd/server.go index 580794239..ae02ed221 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -12,8 +12,13 @@ func newServerCommand() *cli.Command { return &cli.Command{ Name: "server", Usage: "dagu server", + Flags: globalFlags, Action: func(c *cli.Context) error { - return startServer(globalConfig) + cfg, err := loadGlobalConfig(c) + if err != nil { + return err + } + return startServer(cfg) }, } } diff --git a/cmd/start.go b/cmd/start.go index d5a772179..8a8d0bbd8 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -11,18 +11,18 @@ import ( func newStartCommand() *cli.Command { return &cli.Command{ Name: "start", - Usage: "dagu start [--params=\"\"] ", - Flags: []cli.Flag{ + Usage: "dagu start [--params=\"\"] ", + Flags: append( + globalFlags, &cli.StringFlag{ Name: "params", Usage: "parameters", Value: "", Required: false, }, - }, + ), Action: func(c *cli.Context) error { - config_file_path := c.Args().Get(0) - cfg, err := loadDAG(config_file_path, c.String("params")) + cfg, err := loadDAG(c, c.Args().Get(0), c.String("params")) if err != nil { return err } diff --git a/cmd/start_test.go b/cmd/start_test.go index d2614bc66..85eead9bb 100644 --- a/cmd/start_test.go +++ b/cmd/start_test.go @@ -1,6 +1,8 @@ package main import ( + "fmt" + "os" "testing" ) @@ -22,8 +24,17 @@ func Test_startCommand(t *testing.T) { args: []string{"", "start", testConfig("cmd_start_success")}, errored: false, output: []string{"1 finished"}, }, + { + args: []string{"", "start", + fmt.Sprintf("--config=%s", testConfig("cmd_start_global_config.yaml")), + testConfig("cmd_start_global_config_check.yaml")}, errored: false, + output: []string{"GLOBAL_ENV_VAR"}, + }, } + // For testing --config parameter we need to set the environment variable for now. + os.Setenv("TEST_CONFIG_BASE", testsDir) + for _, v := range tests { app := makeApp() runAppTestOutput(app, v, t) diff --git a/cmd/status.go b/cmd/status.go index 33ac86549..3800da2ff 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -13,9 +13,10 @@ import ( func newStatusCommand() *cli.Command { return &cli.Command{ Name: "status", - Usage: "dagu status ", + Usage: "dagu status ", + Flags: globalFlags, Action: func(c *cli.Context) error { - cfg, err := loadDAG(c.Args().Get(0), "") + cfg, err := loadDAG(c, c.Args().Get(0), "") if err != nil { return err } diff --git a/cmd/stop.go b/cmd/stop.go index 797db8e8d..080c06702 100644 --- a/cmd/stop.go +++ b/cmd/stop.go @@ -11,9 +11,10 @@ import ( func newStopCommand() *cli.Command { return &cli.Command{ Name: "stop", - Usage: "dagu stop ", + Usage: "dagu stop ", + Flags: globalFlags, Action: func(c *cli.Context) error { - cfg, err := loadDAG(c.Args().Get(0), "") + cfg, err := loadDAG(c, c.Args().Get(0), "") if err != nil { return err } diff --git a/internal/admin/config.go b/internal/admin/config.go index 0ac90d0b7..f34a0dcf3 100644 --- a/internal/admin/config.go +++ b/internal/admin/config.go @@ -110,7 +110,10 @@ func buildDAGsDir(cfg *Config, def *configDefinition) (err error) { } func buildBaseConfig(cfg *Config, def *configDefinition) (err error) { - cfg.BaseConfig = strings.TrimSpace(def.BaseConfig) + cfg.BaseConfig, err = utils.ParseVariable(strings.TrimSpace(def.BaseConfig)) + if err != nil { + return err + } if cfg.BaseConfig == "" { cfg.BaseConfig = settings.MustGet(settings.SETTING__BASE_CONFIG) } diff --git a/tests/testdata/cmd_start_global_config.yaml b/tests/testdata/cmd_start_global_config.yaml new file mode 100644 index 000000000..e3b7af3e7 --- /dev/null +++ b/tests/testdata/cmd_start_global_config.yaml @@ -0,0 +1 @@ +baseConfig: ${TEST_CONFIG_BASE}/cmd_start_global_config_base.yaml \ No newline at end of file diff --git a/tests/testdata/cmd_start_global_config_base.yaml b/tests/testdata/cmd_start_global_config_base.yaml new file mode 100644 index 000000000..f2c983182 --- /dev/null +++ b/tests/testdata/cmd_start_global_config_base.yaml @@ -0,0 +1,2 @@ +env: + - GLOBAL_ENV: "GLOBAL_ENV_VAR" \ No newline at end of file diff --git a/tests/testdata/cmd_start_global_config_check.yaml b/tests/testdata/cmd_start_global_config_check.yaml new file mode 100644 index 000000000..082107c82 --- /dev/null +++ b/tests/testdata/cmd_start_global_config_check.yaml @@ -0,0 +1,3 @@ +steps: + - name: "1" + command: "echo ${GLOBAL_ENV}" \ No newline at end of file