Skip to content

Commit 82adb40

Browse files
committed
add tests for config.go and resolve a few config bugs
1 parent 9feed80 commit 82adb40

20 files changed

+615
-15
lines changed

cmd/sshproxy/sshproxy.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ func mainExitCode() int {
296296
}
297297
}
298298
} else {
299-
if config.Etcd.Mandatory {
299+
if config.Etcd.Mandatory.(bool) {
300300
log.Fatal("Etcd is mandatory but unavailable")
301301
}
302302
}

config/sshproxy.yaml

+1-4
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,7 @@
204204
# ENV1: /tmp/env
205205
# ssh:
206206
# args: ["-vvv", "-Y"]
207-
# # If routes are specified, each specified route is fully overridden, not merged.
208-
# routes:
209-
# default:
210-
# dest: [hostx]
207+
# dest: [hostx]
211208
# - match:
212209
# - groups: [bar]
213210
# groups: [baz]

pkg/utils/config.go

+71-10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"os"
1717
"regexp"
1818
"slices"
19+
"sort"
1920
"strings"
2021
"time"
2122

@@ -82,13 +83,15 @@ type sshConfig struct {
8283
Args []string `yaml:",flow,omitempty"`
8384
}
8485

86+
// We use interface{} instead of real type to check if the option was specified
87+
// or not.
8588
type etcdConfig struct {
8689
Endpoints []string `yaml:",flow"`
8790
TLS etcdTLSConfig `yaml:",omitempty"`
8891
Username string `yaml:",omitempty"`
8992
Password string `yaml:",omitempty"`
9093
KeyTTL int64 `yaml:",omitempty"`
91-
Mandatory bool `yaml:",omitempty"`
94+
Mandatory interface{} `yaml:",omitempty"`
9295
}
9396

9497
type etcdTLSConfig struct {
@@ -108,12 +111,12 @@ type subConfig struct {
108111
Dump interface{} `yaml:",omitempty"`
109112
DumpLimitSize interface{} `yaml:"dump_limit_size,omitempty"`
110113
DumpLimitWindow interface{} `yaml:"dump_limit_window,omitempty"`
111-
Etcd interface{} `yaml:",omitempty"`
114+
Etcd etcdConfig `yaml:",omitempty"`
112115
EtcdStatsInterval interface{} `yaml:"etcd_stats_interval,omitempty"`
113116
LogStatsInterval interface{} `yaml:"log_stats_interval,omitempty"`
114117
BlockingCommand interface{} `yaml:"blocking_command,omitempty"`
115118
BgCommand interface{} `yaml:"bg_command,omitempty"`
116-
SSH interface{} `yaml:",omitempty"`
119+
SSH sshConfig `yaml:",omitempty"`
117120
TranslateCommands map[string]*TranslateCommandConfig `yaml:"translate_commands,omitempty"`
118121
Environment map[string]string `yaml:",omitempty"`
119122
Service interface{} `yaml:",omitempty"`
@@ -143,8 +146,15 @@ func PrintConfig(config *Config, groups map[string]bool) []string {
143146
output = append(output, fmt.Sprintf("config.blocking_command = %s", config.BlockingCommand))
144147
output = append(output, fmt.Sprintf("config.bg_command = %s", config.BgCommand))
145148
output = append(output, fmt.Sprintf("config.ssh = %+v", config.SSH))
146-
for k, v := range config.TranslateCommands {
147-
output = append(output, fmt.Sprintf("config.TranslateCommands.%s = %+v", k, v))
149+
// Internally, we don't care of TranslateCommands's order. But we want to
150+
// always display it in the same order
151+
keys := make([]string, 0, len(config.TranslateCommands))
152+
for k := range config.TranslateCommands {
153+
keys = append(keys, k)
154+
}
155+
sort.Strings(keys)
156+
for _, k := range keys {
157+
output = append(output, fmt.Sprintf("config.TranslateCommands.%s = %+v", k, config.TranslateCommands[k]))
148158
}
149159
output = append(output, fmt.Sprintf("config.environment = %v", config.Environment))
150160
output = append(output, fmt.Sprintf("config.service = %s", config.Service))
@@ -168,6 +178,9 @@ func parseSubConfig(config *Config, subconfig *subConfig) error {
168178
}
169179

170180
if subconfig.CheckInterval != nil {
181+
if fmt.Sprintf("%T", subconfig.CheckInterval) != "string" {
182+
return fmt.Errorf("check_interval: %v is not a string", subconfig.CheckInterval)
183+
}
171184
var err error
172185
config.CheckInterval, err = time.ParseDuration(subconfig.CheckInterval.(string))
173186
if err != nil {
@@ -188,18 +201,52 @@ func parseSubConfig(config *Config, subconfig *subConfig) error {
188201
}
189202

190203
if subconfig.DumpLimitWindow != nil {
204+
if fmt.Sprintf("%T", subconfig.DumpLimitWindow) != "string" {
205+
return fmt.Errorf("dump_limit_window: %v is not a string", subconfig.DumpLimitWindow)
206+
}
191207
var err error
192208
config.DumpLimitWindow, err = time.ParseDuration(subconfig.DumpLimitWindow.(string))
193209
if err != nil {
194210
return err
195211
}
196212
}
197213

198-
if subconfig.Etcd != nil {
199-
config.Etcd = subconfig.Etcd.(etcdConfig)
214+
if subconfig.Etcd.Endpoints != nil {
215+
config.Etcd.Endpoints = subconfig.Etcd.Endpoints
216+
}
217+
218+
if subconfig.Etcd.TLS.CAFile != "" {
219+
config.Etcd.TLS.CAFile = subconfig.Etcd.TLS.CAFile
220+
}
221+
222+
if subconfig.Etcd.TLS.KeyFile != "" {
223+
config.Etcd.TLS.KeyFile = subconfig.Etcd.TLS.KeyFile
224+
}
225+
226+
if subconfig.Etcd.TLS.CertFile != "" {
227+
config.Etcd.TLS.CertFile = subconfig.Etcd.TLS.CertFile
228+
}
229+
230+
if subconfig.Etcd.Username != "" {
231+
config.Etcd.Username = subconfig.Etcd.Username
232+
}
233+
234+
if subconfig.Etcd.Password != "" {
235+
config.Etcd.Password = subconfig.Etcd.Password
236+
}
237+
238+
if subconfig.Etcd.KeyTTL != 0 {
239+
config.Etcd.KeyTTL = subconfig.Etcd.KeyTTL
240+
}
241+
242+
if subconfig.Etcd.Mandatory != nil {
243+
config.Etcd.Mandatory = subconfig.Etcd.Mandatory
200244
}
201245

202246
if subconfig.EtcdStatsInterval != nil {
247+
if fmt.Sprintf("%T", subconfig.EtcdStatsInterval) != "string" {
248+
return fmt.Errorf("etcd_stats_interval: %v is not a string", subconfig.EtcdStatsInterval)
249+
}
203250
var err error
204251
config.EtcdStatsInterval, err = time.ParseDuration(subconfig.EtcdStatsInterval.(string))
205252
if err != nil {
@@ -208,6 +255,9 @@ func parseSubConfig(config *Config, subconfig *subConfig) error {
208255
}
209256

210257
if subconfig.LogStatsInterval != nil {
258+
if fmt.Sprintf("%T", subconfig.LogStatsInterval) != "string" {
259+
return fmt.Errorf("log_stats_interval: %v is not a string", subconfig.LogStatsInterval)
260+
}
211261
var err error
212262
config.LogStatsInterval, err = time.ParseDuration(subconfig.LogStatsInterval.(string))
213263
if err != nil {
@@ -223,8 +273,12 @@ func parseSubConfig(config *Config, subconfig *subConfig) error {
223273
config.BgCommand = subconfig.BgCommand.(string)
224274
}
225275

226-
if subconfig.SSH != nil {
227-
config.SSH = subconfig.SSH.(sshConfig)
276+
if subconfig.SSH.Exe != "" {
277+
config.SSH.Exe = subconfig.SSH.Exe
278+
}
279+
280+
if subconfig.SSH.Args != nil {
281+
config.SSH.Args = subconfig.SSH.Args
228282
}
229283

230284
// merge translate_commands
@@ -296,7 +350,14 @@ func LoadAllDestsFromConfig(filename string) ([]string, error) {
296350
config.Dest = append(config.Dest, override.Dest...)
297351
}
298352
}
299-
return config.Dest, nil
353+
// expand destination nodesets
354+
_, nodesetDlclose, nodesetExpand := nodesets.InitExpander()
355+
defer nodesetDlclose()
356+
dsts, err := nodesetExpand(strings.Join(config.Dest, ","))
357+
if err != nil {
358+
return nil, fmt.Errorf("invalid nodeset: %s", err)
359+
}
360+
return dsts, nil
300361
}
301362

302363
// LoadConfig load configuration file and adapt it according to specified user/group/sshdHostPort.

0 commit comments

Comments
 (0)