Skip to content

add tests for config.go and resolve a few config bugs #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/sshproxy/sshproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ func mainExitCode() int {
}
}
} else {
if config.Etcd.Mandatory {
if config.Etcd.Mandatory.(bool) {
log.Fatal("Etcd is mandatory but unavailable")
}
}
Expand Down
5 changes: 1 addition & 4 deletions config/sshproxy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,7 @@
# ENV1: /tmp/env
# ssh:
# args: ["-vvv", "-Y"]
# # If routes are specified, each specified route is fully overridden, not merged.
# routes:
# default:
# dest: [hostx]
# dest: [hostx]
# - match:
# - groups: [bar]
# groups: [baz]
Expand Down
81 changes: 71 additions & 10 deletions pkg/utils/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"os"
"regexp"
"slices"
"sort"
"strings"
"time"

Expand Down Expand Up @@ -82,13 +83,15 @@ type sshConfig struct {
Args []string `yaml:",flow,omitempty"`
}

// We use interface{} instead of real type to check if the option was specified
// or not.
type etcdConfig struct {
Endpoints []string `yaml:",flow"`
TLS etcdTLSConfig `yaml:",omitempty"`
Username string `yaml:",omitempty"`
Password string `yaml:",omitempty"`
KeyTTL int64 `yaml:",omitempty"`
Mandatory bool `yaml:",omitempty"`
Mandatory interface{} `yaml:",omitempty"`
}

type etcdTLSConfig struct {
Expand All @@ -108,12 +111,12 @@ type subConfig struct {
Dump interface{} `yaml:",omitempty"`
DumpLimitSize interface{} `yaml:"dump_limit_size,omitempty"`
DumpLimitWindow interface{} `yaml:"dump_limit_window,omitempty"`
Etcd interface{} `yaml:",omitempty"`
Etcd etcdConfig `yaml:",omitempty"`
EtcdStatsInterval interface{} `yaml:"etcd_stats_interval,omitempty"`
LogStatsInterval interface{} `yaml:"log_stats_interval,omitempty"`
BlockingCommand interface{} `yaml:"blocking_command,omitempty"`
BgCommand interface{} `yaml:"bg_command,omitempty"`
SSH interface{} `yaml:",omitempty"`
SSH sshConfig `yaml:",omitempty"`
TranslateCommands map[string]*TranslateCommandConfig `yaml:"translate_commands,omitempty"`
Environment map[string]string `yaml:",omitempty"`
Service interface{} `yaml:",omitempty"`
Expand Down Expand Up @@ -143,8 +146,15 @@ func PrintConfig(config *Config, groups map[string]bool) []string {
output = append(output, fmt.Sprintf("config.blocking_command = %s", config.BlockingCommand))
output = append(output, fmt.Sprintf("config.bg_command = %s", config.BgCommand))
output = append(output, fmt.Sprintf("config.ssh = %+v", config.SSH))
for k, v := range config.TranslateCommands {
output = append(output, fmt.Sprintf("config.TranslateCommands.%s = %+v", k, v))
// Internally, we don't care of TranslateCommands's order. But we want to
// always display it in the same order
keys := make([]string, 0, len(config.TranslateCommands))
for k := range config.TranslateCommands {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
output = append(output, fmt.Sprintf("config.TranslateCommands.%s = %+v", k, config.TranslateCommands[k]))
}
output = append(output, fmt.Sprintf("config.environment = %v", config.Environment))
output = append(output, fmt.Sprintf("config.service = %s", config.Service))
Expand All @@ -168,6 +178,9 @@ func parseSubConfig(config *Config, subconfig *subConfig) error {
}

if subconfig.CheckInterval != nil {
if fmt.Sprintf("%T", subconfig.CheckInterval) != "string" {
return fmt.Errorf("check_interval: %v is not a string", subconfig.CheckInterval)
}
var err error
config.CheckInterval, err = time.ParseDuration(subconfig.CheckInterval.(string))
if err != nil {
Expand All @@ -188,18 +201,52 @@ func parseSubConfig(config *Config, subconfig *subConfig) error {
}

if subconfig.DumpLimitWindow != nil {
if fmt.Sprintf("%T", subconfig.DumpLimitWindow) != "string" {
return fmt.Errorf("dump_limit_window: %v is not a string", subconfig.DumpLimitWindow)
}
var err error
config.DumpLimitWindow, err = time.ParseDuration(subconfig.DumpLimitWindow.(string))
if err != nil {
return err
}
}

if subconfig.Etcd != nil {
config.Etcd = subconfig.Etcd.(etcdConfig)
if subconfig.Etcd.Endpoints != nil {
config.Etcd.Endpoints = subconfig.Etcd.Endpoints
}

if subconfig.Etcd.TLS.CAFile != "" {
config.Etcd.TLS.CAFile = subconfig.Etcd.TLS.CAFile
}

if subconfig.Etcd.TLS.KeyFile != "" {
config.Etcd.TLS.KeyFile = subconfig.Etcd.TLS.KeyFile
}

if subconfig.Etcd.TLS.CertFile != "" {
config.Etcd.TLS.CertFile = subconfig.Etcd.TLS.CertFile
}

if subconfig.Etcd.Username != "" {
config.Etcd.Username = subconfig.Etcd.Username
}

if subconfig.Etcd.Password != "" {
config.Etcd.Password = subconfig.Etcd.Password
}

if subconfig.Etcd.KeyTTL != 0 {
config.Etcd.KeyTTL = subconfig.Etcd.KeyTTL
}

if subconfig.Etcd.Mandatory != nil {
config.Etcd.Mandatory = subconfig.Etcd.Mandatory
}

if subconfig.EtcdStatsInterval != nil {
if fmt.Sprintf("%T", subconfig.EtcdStatsInterval) != "string" {
return fmt.Errorf("etcd_stats_interval: %v is not a string", subconfig.EtcdStatsInterval)
}
var err error
config.EtcdStatsInterval, err = time.ParseDuration(subconfig.EtcdStatsInterval.(string))
if err != nil {
Expand All @@ -208,6 +255,9 @@ func parseSubConfig(config *Config, subconfig *subConfig) error {
}

if subconfig.LogStatsInterval != nil {
if fmt.Sprintf("%T", subconfig.LogStatsInterval) != "string" {
return fmt.Errorf("log_stats_interval: %v is not a string", subconfig.LogStatsInterval)
}
var err error
config.LogStatsInterval, err = time.ParseDuration(subconfig.LogStatsInterval.(string))
if err != nil {
Expand All @@ -223,8 +273,12 @@ func parseSubConfig(config *Config, subconfig *subConfig) error {
config.BgCommand = subconfig.BgCommand.(string)
}

if subconfig.SSH != nil {
config.SSH = subconfig.SSH.(sshConfig)
if subconfig.SSH.Exe != "" {
config.SSH.Exe = subconfig.SSH.Exe
}

if subconfig.SSH.Args != nil {
config.SSH.Args = subconfig.SSH.Args
}

// merge translate_commands
Expand Down Expand Up @@ -296,7 +350,14 @@ func LoadAllDestsFromConfig(filename string) ([]string, error) {
config.Dest = append(config.Dest, override.Dest...)
}
}
return config.Dest, nil
// expand destination nodesets
_, nodesetDlclose, nodesetExpand := nodesets.InitExpander()
defer nodesetDlclose()
dsts, err := nodesetExpand(strings.Join(config.Dest, ","))
if err != nil {
return nil, fmt.Errorf("invalid nodeset: %s", err)
}
return dsts, nil
}

// LoadConfig load configuration file and adapt it according to specified user/group/sshdHostPort.
Expand Down
Loading