diff --git a/internal/testing/config.go b/internal/testing/config.go index 0521bc298..c58db27ea 100644 --- a/internal/testing/config.go +++ b/internal/testing/config.go @@ -60,7 +60,13 @@ func NewIntegrationTestConfig(t *testing.T) config.Config { cfg.AdminAPIKey = envAdminAPIKey if envRegion != "" { - err := cfg.SetRegion(region.Parse(envRegion)) + regName, err := region.Parse(envRegion) + assert.NoError(t, err) + + reg, err := region.Get(regName) + assert.NoError(t, err) + + err = cfg.SetRegion(reg) assert.NoError(t, err) } diff --git a/newrelic/newrelic.go b/newrelic/newrelic.go index f660b4028..a3f516106 100644 --- a/newrelic/newrelic.go +++ b/newrelic/newrelic.go @@ -91,12 +91,14 @@ func ConfigAdminAPIKey(adminAPIKey string) ConfigOption { // ConfigRegion sets the New Relic Region this client will use. func ConfigRegion(r region.Name) ConfigOption { return func(cfg *config.Config) error { - if region, ok := region.Regions[r]; ok { - regCopy := *region - return cfg.SetRegion(®Copy) + reg, err := region.Get(r) + if err != nil { + return err } - return errors.New("unsupported region configured") + err = cfg.SetRegion(reg) + + return err } } diff --git a/pkg/config/config.go b/pkg/config/config.go index 8d2bd7dd9..784d18e83 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -49,10 +49,10 @@ type Config struct { // New creates a default configuration and returns it func New() Config { - regCopy := *region.Default + reg, _ := region.Get(region.Default) return Config{ - region: ®Copy, + region: reg, UserAgent: "newrelic/newrelic-client-go", LogLevel: "info", } @@ -62,8 +62,8 @@ func New() Config { // if one has not been set, use the default region func (c *Config) Region() *region.Region { if c.region == nil { - regCopy := *region.Default - c.region = ®Copy + reg, _ := region.Get(region.Default) + c.region = reg } return c.region diff --git a/pkg/region/errors.go b/pkg/region/errors.go index db8c91b6b..9685268f6 100644 --- a/pkg/region/errors.go +++ b/pkg/region/errors.go @@ -24,3 +24,29 @@ func ErrorNil() InvalidError { Message: "value is nil", } } + +type UnknownError struct { + Message string +} + +func (e UnknownError) Error() string { + if e.Message != "" { + return fmt.Sprintf("unknown region: %s", e.Message) + } + + return "unknown region" +} + +// UnknownUsingDefaultError returns when the Region requested is not valid, but we want to give them something +type UnknownUsingDefaultError struct { + Message string +} + +// Error string reported when an InvalidError happens +func (e UnknownUsingDefaultError) Error() string { + if e.Message != "" { + return fmt.Sprintf("unknown region: %s, using default: %s", e.Message, Default.String()) + } + + return fmt.Sprintf("unknown region, using default: %s", Default.String()) +} diff --git a/pkg/region/errors_test.go b/pkg/region/errors_test.go new file mode 100644 index 000000000..c8d15373b --- /dev/null +++ b/pkg/region/errors_test.go @@ -0,0 +1,44 @@ +// +build unit + +package region + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestInvalidError(t *testing.T) { + t.Parallel() + + err := InvalidError{} + assert.EqualError(t, err, "invalid region") + + err = InvalidError{Message: "asdf"} + assert.EqualError(t, err, "invalid region: asdf") + + // Custom func for nils + err = ErrorNil() + assert.Error(t, err) + assert.EqualError(t, err, "invalid region: value is nil") +} + +func TestUnknownError(t *testing.T) { + t.Parallel() + + err := UnknownError{} + assert.EqualError(t, err, "unknown region") + + err = UnknownError{Message: "test"} + assert.EqualError(t, err, "unknown region: test") +} + +func TestUnknownUsingDefaultError(t *testing.T) { + t.Parallel() + + err := UnknownUsingDefaultError{} + assert.EqualError(t, err, "unknown region, using default: "+Default.String()) + + err = UnknownUsingDefaultError{Message: "test"} + assert.EqualError(t, err, "unknown region: test, using default: "+Default.String()) +} diff --git a/pkg/region/region.go b/pkg/region/region.go index 83c8d8c83..8c33cfbb5 100644 --- a/pkg/region/region.go +++ b/pkg/region/region.go @@ -24,6 +24,11 @@ type Region struct { nerdGraphBaseURL string } +// String returns a human readable value for the specified Region Name +func (n Name) String() string { + return string(n) +} + // String returns a human readable value for the specified Region func (r *Region) String() string { if r != nil && r.name != "" { diff --git a/pkg/region/region_constants.go b/pkg/region/region_constants.go index ff78d8330..509d0d1d6 100644 --- a/pkg/region/region_constants.go +++ b/pkg/region/region_constants.go @@ -41,22 +41,27 @@ var Regions = map[Name]*Region{ } // Default represents the region returned if nothing was specified -var Default *Region = Regions[US] +const Default Name = US // Parse takes a Region string and returns a RegionType -func Parse(r string) *Region { - var ret Region - +func Parse(r string) (Name, error) { switch strings.ToLower(r) { case "us": - ret = *Regions[US] + return US, nil case "eu": - ret = *Regions[EU] + return EU, nil case "staging": - ret = *Regions[Staging] + return Staging, nil default: - ret = *Default + return "", UnknownError{Message: r} + } +} + +func Get(r Name) (*Region, error) { + if reg, ok := Regions[r]; ok { + ret := *reg // Make a copy + return &ret, nil } - return &ret + return Regions[Default], UnknownUsingDefaultError{Message: r.String()} } diff --git a/pkg/region/region_test.go b/pkg/region/region_test.go index ca69178fd..322d747c8 100644 --- a/pkg/region/region_test.go +++ b/pkg/region/region_test.go @@ -11,31 +11,57 @@ import ( func TestParse(t *testing.T) { t.Parallel() - pairs := map[string]*Region{ - "us": Regions[US], - "Us": Regions[US], - "uS": Regions[US], - "US": Regions[US], - "eu": Regions[EU], - "Eu": Regions[EU], - "eU": Regions[EU], - "EU": Regions[EU], - "staging": Regions[Staging], - "Staging": Regions[Staging], - "STAGING": Regions[Staging], + pairs := map[string]Name{ + "us": US, + "Us": US, + "uS": US, + "US": US, + "eu": EU, + "Eu": EU, + "eU": EU, + "EU": EU, + "staging": Staging, + "Staging": Staging, + "STAGING": Staging, } for k, v := range pairs { - result := Parse(k) - assert.Equal(t, result, v) + result, err := Parse(k) + assert.NoError(t, err) + assert.Equal(t, v, result) } // Default is US - result := Parse("") - assert.Equal(t, result, Regions[US]) + result, err := Parse("") + assert.Error(t, err) + assert.IsType(t, UnknownError{}, err) + assert.Equal(t, Name(""), result) +} + +func TestRegionGet(t *testing.T) { + t.Parallel() + + pairs := map[Name]*Region{ + US: Regions[US], + EU: Regions[EU], + Staging: Regions[Staging], + } + + for k, v := range pairs { + result, err := Get(k) + assert.NoError(t, err) + assert.Equal(t, v, result) + } + + // Throws error, still returns the default + var unk Name = "(unknown)" + result, err := Get(unk) + assert.Error(t, err) + assert.IsType(t, UnknownUsingDefaultError{}, err) + assert.Equal(t, Regions[Default], result) } -func TestString(t *testing.T) { +func TestRegionString(t *testing.T) { t.Parallel() pairs := map[Name]string{