Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix config bug where mockery crashes when package map is nil #730

Merged
merged 7 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ tasks:

test:
cmds:
- go test -v -coverprofile=coverage.txt ./...
- go run gotest.tools/gotestsum --format testname -- -v -coverprofile=coverage.txt ./...
desc: run unit tests
sources:
- "**/*.go"
Expand Down
10 changes: 6 additions & 4 deletions cmd/showconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ func showConfig(
if err != nil {
return stackerr.NewStackErrf(err, "failed to unmarshal config")
}
log, err := logging.GetLogger(config.LogLevel)
if err != nil {
return fmt.Errorf("getting logger: %w", err)
}
ctx = log.WithContext(ctx)
if err := config.Initialize(ctx); err != nil {
return err
}
Expand All @@ -48,10 +53,7 @@ func showConfig(
if err != nil {
return stackerr.NewStackErrf(err, "failed to marshal yaml")
}
log, err := logging.GetLogger(config.LogLevel)
if err != nil {
panic(err)
}

log.Info().Msgf("Using config: %s", config.Config)

fmt.Fprintf(outputter, "%s", string(out))
Expand Down
151 changes: 114 additions & 37 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,14 @@ func (c *Config) GetPackages(ctx context.Context) ([]string, error) {
return packageList, nil
}

// getPackageConfigMap returns the map for the particular package, which includes
// (but is not limited to) both the `configs` section and the `interfaces` section.
// Note this does NOT return the `configs` section for the package. It returns the
// entire mapping for the package.
func (c *Config) getPackageConfigMap(ctx context.Context, packageName string) (map[string]any, error) {
log := zerolog.Ctx(ctx)
log.Trace().Msg("getting package config map")

cfgMap, err := c.CfgAsMap(ctx)
if err != nil {
return nil, err
Expand All @@ -207,13 +214,27 @@ func (c *Config) getPackageConfigMap(ctx context.Context, packageName string) (m
}
configAsMap, isMap := configUnmerged.(map[string]any)
if isMap {
log.Trace().Msg("package's value is a map, returning")
return configAsMap, nil
}
return map[string]any{}, nil
log.Trace().Msg("package's value is not a map")

// Package is something other than map, so set its value to an
// empty map.
emptyMap := map[string]any{}
packageSection[packageName] = emptyMap
return emptyMap, nil

}

// GetPackageConfig returns a struct representation of the package's config
// as provided in yaml. If the package did not specify a config section,
// this method will inject the top-level config into the package's config.
// This is especially useful as it allows us to lazily evaluate a package's
// config. If the package does specify config, this method takes care to merge
// the top-level config with the values specified for this package.
func (c *Config) GetPackageConfig(ctx context.Context, packageName string) (*Config, error) {
log := zerolog.Ctx(ctx)
log := zerolog.Ctx(ctx).With().Str("package-path", packageName).Logger()

if c.pkgConfigCache == nil {
log.Debug().Msg("package cache is nil")
Expand All @@ -223,11 +244,10 @@ func (c *Config) GetPackageConfig(ctx context.Context, packageName string) (*Con
return pkgConf, nil
}

pkgConfig := reflect.New(reflect.ValueOf(c).Elem().Type()).Interface()
pkgConfig := &Config{}
if err := copier.Copy(pkgConfig, c); err != nil {
return nil, fmt.Errorf("failed to copy config: %w", err)
}
pkgConfigTyped := pkgConfig.(*Config)

configMap, err := c.getPackageConfigMap(ctx, packageName)
if err != nil {
Expand All @@ -237,18 +257,24 @@ func (c *Config) GetPackageConfig(ctx context.Context, packageName string) (*Con
configSection, ok := configMap["config"]
if !ok {
log.Debug().Msg("config section not provided for package")
return pkgConfigTyped, nil
configMap["config"] = map[string]any{}
c.pkgConfigCache[packageName] = pkgConfig
return pkgConfig, nil
}

decoder, err := c.getDecoder(pkgConfigTyped)
// We know that the package specified config that is overriding the top-level
// config. We use a mapstructure decoder to decode the values in the yaml
// into the pkgConfig struct. This has the effect of merging top-level
// config with package-level config.
decoder, err := c.getDecoder(pkgConfig)
if err != nil {
return nil, stackerr.NewStackErrf(err, "failed to get decoder")
}
if err := decoder.Decode(configSection); err != nil {
return nil, err
}
c.pkgConfigCache[packageName] = pkgConfigTyped
return pkgConfigTyped, nil
c.pkgConfigCache[packageName] = pkgConfig
return pkgConfig, nil
}

func (c *Config) ExcludePath(path string) bool {
Expand Down Expand Up @@ -347,16 +373,15 @@ func (c *Config) GetInterfaceConfig(ctx context.Context, packageName string, int
}

// Copy the package-level config to our interface-level config
pkgConfigCopy := reflect.New(reflect.ValueOf(pkgConfig).Elem().Type()).Interface()
pkgConfigCopy := &Config{}
if err := copier.Copy(pkgConfigCopy, pkgConfig); err != nil {
return nil, stackerr.NewStackErrf(err, "failed to create a copy of package config")
}
baseConfigTyped := pkgConfigCopy.(*Config)

interfaceSection, ok := interfacesSection[interfaceName]
if !ok {
log.Debug().Msg("interface not defined in package configuration")
return []*Config{baseConfigTyped}, nil
return []*Config{pkgConfigCopy}, nil
}

interfaceSectionTyped, ok := interfaceSection.(map[string]any)
Expand All @@ -365,7 +390,7 @@ func (c *Config) GetInterfaceConfig(ctx context.Context, packageName string, int
// the interface but not provide any additional config beyond what
// is provided at the package level
if reflect.ValueOf(&interfaceSection).Elem().IsZero() {
return []*Config{baseConfigTyped}, nil
return []*Config{pkgConfigCopy}, nil
}
msgString := "bad type provided for interface config"
log.Error().Msgf(msgString)
Expand All @@ -376,11 +401,11 @@ func (c *Config) GetInterfaceConfig(ctx context.Context, packageName string, int
if ok {
log.Debug().Msg("config section exists for interface")
// if `config` is provided, we'll overwrite the values in our
// baseConfigTyped struct to act as the "new" base config.
// pkgConfigCopy struct to act as the "new" base config.
// This will allow us to set the default values for the interface
// but override them further for each mock defined in the
// `configs` section.
decoder, err := c.getDecoder(baseConfigTyped)
decoder, err := c.getDecoder(pkgConfigCopy)
if err != nil {
return nil, stackerr.NewStackErrf(err, "unable to create mapstructure decoder")
}
Expand All @@ -397,8 +422,8 @@ func (c *Config) GetInterfaceConfig(ctx context.Context, packageName string, int
configsSectionTyped := configsSection.([]any)
for _, configMap := range configsSectionTyped {
// Create a copy of the package-level config
currentInterfaceConfig := reflect.New(reflect.ValueOf(baseConfigTyped).Elem().Type()).Interface()
if err := copier.Copy(currentInterfaceConfig, baseConfigTyped); err != nil {
currentInterfaceConfig := reflect.New(reflect.ValueOf(pkgConfigCopy).Elem().Type()).Interface()
if err := copier.Copy(currentInterfaceConfig, pkgConfigCopy); err != nil {
return nil, stackerr.NewStackErrf(err, "failed to copy package config")
}

Expand All @@ -418,7 +443,7 @@ func (c *Config) GetInterfaceConfig(ctx context.Context, packageName string, int
log.Debug().Msg("configs section doesn't exist for interface")

if len(configs) == 0 {
configs = append(configs, baseConfigTyped)
configs = append(configs, pkgConfigCopy)
}
return configs, nil
}
Expand All @@ -429,31 +454,33 @@ func (c *Config) addSubPkgConfig(ctx context.Context, subPkgPath string, parentP
log := zerolog.Ctx(ctx).With().
Str("parent-package", parentPkgPath).
Str("sub-package", subPkgPath).Logger()
ctx = log.WithContext(ctx)

log.Trace().Msg("adding sub-package to config map")
log.Debug().Msg("adding sub-package to config map")
parentPkgConfig, err := c.getPackageConfigMap(ctx, parentPkgPath)
if err != nil {
log.Err(err).
Msg("failed to get package config for parent package")
return fmt.Errorf("failed to get package config: %w", err)
}

log.Trace().Msg("getting config")
cfg, err := c.CfgAsMap(ctx)
log.Debug().Msg("getting config")
topLevelConfig, err := c.CfgAsMap(ctx)
if err != nil {
return fmt.Errorf("failed to get configuration map: %w", err)
}

log.Trace().Msg("getting packages section")
packagesSection := cfg["packages"].(map[string]any)
log.Debug().Msg("getting packages section")
packagesSection := topLevelConfig["packages"].(map[string]any)

// Don't overwrite any config that already exists
_, pkgExists := packagesSection[subPkgPath]
if !pkgExists {
log.Trace().Msg("sub-package doesn't exist in config")

// Copy the parent package directly into the subpackage config section
packagesSection[subPkgPath] = map[string]any{}
newPkgSection := packagesSection[subPkgPath].(map[string]any)
newPkgSection["config"] = parentPkgConfig["config"]
newPkgSection["config"] = deepCopyConfigMap(parentPkgConfig["config"].(map[string]any))
} else {
log.Trace().Msg("sub-package exists in config")
// The sub-package exists in config. Check if it has its
Expand All @@ -465,10 +492,15 @@ func (c *Config) addSubPkgConfig(ctx context.Context, subPkgPath string, parentP
log.Err(err).Msg("could not get child package config")
return fmt.Errorf("failed to get sub-package config: %w", err)
}

for key, val := range parentPkgConfig {
if _, keyInSubPkg := subPkgConfig[key]; !keyInSubPkg {
subPkgConfig[key] = val
log.Trace().Msgf("sub-package config: %v", subPkgConfig)
log.Trace().Msgf("parent-package config: %v", parentPkgConfig)

// Merge the parent config with the sub-package config.
parentConfigSection := parentPkgConfig["config"].(map[string]any)
subPkgConfigSection := subPkgConfig["config"].(map[string]any)
for key, val := range parentConfigSection {
if _, keyInSubPkg := subPkgConfigSection[key]; !keyInSubPkg {
subPkgConfigSection[key] = val
}

}
Expand Down Expand Up @@ -595,18 +627,24 @@ func (c *Config) subPackages(
// recursive and recurses the file tree to find all sub-packages.
func (c *Config) discoverRecursivePackages(ctx context.Context) error {
log := zerolog.Ctx(ctx)
log.Trace().Msg("discovering recursive packages")
recursivePackages := map[string]*Config{}
packageList, err := c.GetPackages(ctx)
if err != nil {
return fmt.Errorf("failed to get packages: %w", err)
}
for _, pkg := range packageList {
pkgConfig, err := c.GetPackageConfig(ctx, pkg)
pkgLog := log.With().Str("package", pkg).Logger()
pkgLog.Trace().Msg("iterating over package")
if err != nil {
return fmt.Errorf("failed to get package config: %w", err)
}
if pkgConfig.Recursive {
pkgLog.Trace().Msg("package marked as recursive")
recursivePackages[pkg] = pkgConfig
} else {
pkgLog.Trace().Msg("package not marked as recursive")
}
}
if len(recursivePackages) == 0 {
Expand Down Expand Up @@ -658,6 +696,17 @@ func contains[T comparable](slice []T, elem T) bool {
return false
}

func deepCopyConfigMap(src map[string]any) map[string]any {
newMap := map[string]any{}
for key, val := range src {
if contains([]string{"packages", "config", "interfaces"}, key) {
continue
}
newMap[key] = val
}
return newMap
}

// mergeInConfig takes care of merging inheritable configuration
// in the config map. For example, it merges default config, then
// package-level config, then interface-level config.
Expand All @@ -677,34 +726,62 @@ func (c *Config) mergeInConfig(ctx context.Context) error {
}
for _, pkgPath := range pkgs {
pkgLog := log.With().Str("package-path", pkgPath).Logger()
pkgCtx := pkgLog.WithContext(ctx)

pkgLog.Trace().Msg("merging for package")
packageConfig, err := c.getPackageConfigMap(ctx, pkgPath)
packageConfig, err := c.getPackageConfigMap(pkgCtx, pkgPath)
if err != nil {
pkgLog.Err(err).Msg("failed to get package config")
return fmt.Errorf("failed to get package config: %w", err)
}
pkgLog.Trace().Msgf("got package config map: %v", packageConfig)

configSectionUntyped, configExists := packageConfig["config"]
if !configExists {
packageConfig["config"] = defaultCfg
continue
// The reason why this should never happen is because getPackageConfigMap
// should be populating the config section with the top-level config if it
// wasn't defined in the yaml.
msg := "config section does not exist for package, this should never happen"
pkgLog.Error().Msg(msg)
return fmt.Errorf(msg)
}

pkgLog.Trace().Msg("got config section for package")
// Sometimes the config section may be provided, but it's nil.
// We need to account for this fact.
if configSectionUntyped == nil {
pkgLog.Trace().Msg("config section is nil, converting to empty map")
emptyMap := map[string]any{}

// We need to add this to the "global" config mapping so the change
// gets persisted, and also into configSectionUntyped for the logic
// further down.
packageConfig["config"] = emptyMap
configSectionUntyped = emptyMap
} else {
pkgLog.Trace().Msg("config section is not nil")
}
packageConfigSection := configSectionUntyped.(map[string]any)

configSectionTyped := configSectionUntyped.(map[string]any)

for key, value := range defaultCfg {
if contains([]string{"packages", "config"}, key) {
continue
}
_, keyExists := packageConfigSection[key]
keyValLog := pkgLog.With().Str("key", key).Str("value", fmt.Sprintf("%v", value)).Logger()

_, keyExists := configSectionTyped[key]
if !keyExists {
packageConfigSection[key] = value
keyValLog.Trace().Msg("setting key to value")
configSectionTyped[key] = value
}
}
interfaces, err := c.getInterfacesForPackage(ctx, pkgPath)
interfaces, err := c.getInterfacesForPackage(pkgCtx, pkgPath)
if err != nil {
return fmt.Errorf("failed to get interfaces for package: %w", err)
}
for _, interfaceName := range interfaces {
interfacesSection, err := c.getInterfacesSection(ctx, pkgPath)
interfacesSection, err := c.getInterfacesSection(pkgCtx, pkgPath)
if err != nil {
return err
}
Expand All @@ -728,7 +805,7 @@ func (c *Config) mergeInConfig(ctx context.Context) error {
// Assume this interface's value in the map is nil. Just skip it.
continue
}
for key, value := range packageConfigSection {
for key, value := range configSectionTyped {
if key == "packages" {
continue
}
Expand Down
Loading