diff --git a/jwa/BUILD.bazel b/jwa/BUILD.bazel index 63fffcf90..aa05d4b3a 100644 --- a/jwa/BUILD.bazel +++ b/jwa/BUILD.bazel @@ -9,10 +9,14 @@ go_library( "jwa.go", "key_encryption_gen.go", "key_type_gen.go", + "options_gen.go", "signature_gen.go", ], importpath = "github.com/lestrrat-go/jwx/v2/jwa", visibility = ["//visibility:public"], + deps = [ + "@com_github_lestrrat_go_option//:option", + ], ) go_test( @@ -29,6 +33,7 @@ go_test( deps = [ ":jwa", "@com_github_stretchr_testify//assert", + "@com_github_lestrrat_go_option//:option", ], ) diff --git a/jwa/compression_gen_test.go b/jwa/compression_gen_test.go index 016e6093e..109119912 100644 --- a/jwa/compression_gen_test.go +++ b/jwa/compression_gen_test.go @@ -114,3 +114,58 @@ func TestCompressionAlgorithm(t *testing.T) { } }) } + +// Note: this test can NOT be run in parallel as it uses options with global effect. +func TestCompressionAlgorithmCustomAlgorithm(t *testing.T) { + // These subtests can NOT be run in parallel as options with global effect change. + customAlgorithm := jwa.CompressionAlgorithm("custom-algorithm") + // Unregister the custom algorithm, in case tests fail. + t.Cleanup(func() { + jwa.UnregisterCompressionAlgorithm(customAlgorithm) + }) + t.Run(`with custom algorithm registered`, func(t *testing.T) { + jwa.RegisterCompressionAlgorithm(customAlgorithm) + t.Run(`accept variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.CompressionAlgorithm + if !assert.NoError(t, dst.Accept(customAlgorithm), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.CompressionAlgorithm + if !assert.NoError(t, dst.Accept(`custom-algorithm`), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.CompressionAlgorithm + if !assert.NoError(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + }) + t.Run(`with custom algorithm deregistered`, func(t *testing.T) { + jwa.UnregisterCompressionAlgorithm(customAlgorithm) + t.Run(`reject variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.CompressionAlgorithm + assert.Error(t, dst.Accept(customAlgorithm), `accept failed`) + }) + t.Run(`reject the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.CompressionAlgorithm + assert.Error(t, dst.Accept(`custom-algorithm`), `accept failed`) + }) + t.Run(`reject fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.CompressionAlgorithm + assert.Error(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept failed`) + }) + }) +} diff --git a/jwa/content_encryption_gen_test.go b/jwa/content_encryption_gen_test.go index 7eec98652..0dff6252c 100644 --- a/jwa/content_encryption_gen_test.go +++ b/jwa/content_encryption_gen_test.go @@ -262,3 +262,58 @@ func TestContentEncryptionAlgorithm(t *testing.T) { } }) } + +// Note: this test can NOT be run in parallel as it uses options with global effect. +func TestContentEncryptionAlgorithmCustomAlgorithm(t *testing.T) { + // These subtests can NOT be run in parallel as options with global effect change. + customAlgorithm := jwa.ContentEncryptionAlgorithm("custom-algorithm") + // Unregister the custom algorithm, in case tests fail. + t.Cleanup(func() { + jwa.UnregisterContentEncryptionAlgorithm(customAlgorithm) + }) + t.Run(`with custom algorithm registered`, func(t *testing.T) { + jwa.RegisterContentEncryptionAlgorithm(customAlgorithm) + t.Run(`accept variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.ContentEncryptionAlgorithm + if !assert.NoError(t, dst.Accept(customAlgorithm), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.ContentEncryptionAlgorithm + if !assert.NoError(t, dst.Accept(`custom-algorithm`), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.ContentEncryptionAlgorithm + if !assert.NoError(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + }) + t.Run(`with custom algorithm deregistered`, func(t *testing.T) { + jwa.UnregisterContentEncryptionAlgorithm(customAlgorithm) + t.Run(`reject variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.ContentEncryptionAlgorithm + assert.Error(t, dst.Accept(customAlgorithm), `accept failed`) + }) + t.Run(`reject the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.ContentEncryptionAlgorithm + assert.Error(t, dst.Accept(`custom-algorithm`), `accept failed`) + }) + t.Run(`reject fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.ContentEncryptionAlgorithm + assert.Error(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept failed`) + }) + }) +} diff --git a/jwa/elliptic_gen_test.go b/jwa/elliptic_gen_test.go index badec268d..4646cadd9 100644 --- a/jwa/elliptic_gen_test.go +++ b/jwa/elliptic_gen_test.go @@ -311,3 +311,58 @@ func TestEllipticCurveAlgorithm(t *testing.T) { } }) } + +// Note: this test can NOT be run in parallel as it uses options with global effect. +func TestEllipticCurveAlgorithmCustomAlgorithm(t *testing.T) { + // These subtests can NOT be run in parallel as options with global effect change. + customAlgorithm := jwa.EllipticCurveAlgorithm("custom-algorithm") + // Unregister the custom algorithm, in case tests fail. + t.Cleanup(func() { + jwa.UnregisterEllipticCurveAlgorithm(customAlgorithm) + }) + t.Run(`with custom algorithm registered`, func(t *testing.T) { + jwa.RegisterEllipticCurveAlgorithm(customAlgorithm) + t.Run(`accept variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.EllipticCurveAlgorithm + if !assert.NoError(t, dst.Accept(customAlgorithm), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.EllipticCurveAlgorithm + if !assert.NoError(t, dst.Accept(`custom-algorithm`), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.EllipticCurveAlgorithm + if !assert.NoError(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + }) + t.Run(`with custom algorithm deregistered`, func(t *testing.T) { + jwa.UnregisterEllipticCurveAlgorithm(customAlgorithm) + t.Run(`reject variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.EllipticCurveAlgorithm + assert.Error(t, dst.Accept(customAlgorithm), `accept failed`) + }) + t.Run(`reject the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.EllipticCurveAlgorithm + assert.Error(t, dst.Accept(`custom-algorithm`), `accept failed`) + }) + t.Run(`reject fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.EllipticCurveAlgorithm + assert.Error(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept failed`) + }) + }) +} diff --git a/jwa/jwa.go b/jwa/jwa.go index f9ce38e04..64aef89a8 100644 --- a/jwa/jwa.go +++ b/jwa/jwa.go @@ -13,7 +13,7 @@ import "fmt" // like other fields. // // Ideally we would like to keep track of Signature Algorithms and -// Content Encryption Algorithms separately, and force the APIs to +// Key Encryption Algorithms separately, and force the APIs to // type-check at compile time, but this allows users to pass a value from a // jwk.Key directly type KeyAlgorithm interface { diff --git a/jwa/key_encryption_gen.go b/jwa/key_encryption_gen.go index db47f79f8..67fdc3928 100644 --- a/jwa/key_encryption_gen.go +++ b/jwa/key_encryption_gen.go @@ -37,6 +37,7 @@ const ( var muKeyEncryptionAlgorithms sync.RWMutex var allKeyEncryptionAlgorithms map[KeyEncryptionAlgorithm]struct{} var listKeyEncryptionAlgorithm []KeyEncryptionAlgorithm +var symmetricKeyEncryptionAlgorithms map[KeyEncryptionAlgorithm]struct{} func init() { muKeyEncryptionAlgorithms.Lock() @@ -61,16 +62,47 @@ func init() { allKeyEncryptionAlgorithms[RSA_OAEP_256] = struct{}{} allKeyEncryptionAlgorithms[RSA_OAEP_384] = struct{}{} allKeyEncryptionAlgorithms[RSA_OAEP_512] = struct{}{} + symmetricKeyEncryptionAlgorithms = make(map[KeyEncryptionAlgorithm]struct{}) + symmetricKeyEncryptionAlgorithms[A128GCMKW] = struct{}{} + symmetricKeyEncryptionAlgorithms[A128KW] = struct{}{} + symmetricKeyEncryptionAlgorithms[A192GCMKW] = struct{}{} + symmetricKeyEncryptionAlgorithms[A192KW] = struct{}{} + symmetricKeyEncryptionAlgorithms[A256GCMKW] = struct{}{} + symmetricKeyEncryptionAlgorithms[A256KW] = struct{}{} + symmetricKeyEncryptionAlgorithms[DIRECT] = struct{}{} + symmetricKeyEncryptionAlgorithms[PBES2_HS256_A128KW] = struct{}{} + symmetricKeyEncryptionAlgorithms[PBES2_HS384_A192KW] = struct{}{} + symmetricKeyEncryptionAlgorithms[PBES2_HS512_A256KW] = struct{}{} rebuildKeyEncryptionAlgorithm() } // RegisterKeyEncryptionAlgorithm registers a new KeyEncryptionAlgorithm so that the jwx can properly handle the new value. // Duplicates will silently be ignored func RegisterKeyEncryptionAlgorithm(v KeyEncryptionAlgorithm) { + RegisterKeyEncryptionAlgorithmWithOptions(v) +} + +// RegisterKeyEncryptionAlgorithmWithOptions is the same as RegisterKeyEncryptionAlgorithm when used without options, +// but allows its behavior to change based on the provided options. +// This is an experimental AND stopgap function which will most likely be merged in RegisterKeyEncryptionAlgorithm, and subsequently removed in the future. As such it should not be considered part of the stable API -- it is still subject to change. +// +// You can pass `WithSymmetricAlgorithm(true)` to let the library know that it's a symmetric algorithm. This library makes no attempt to verify if the algorithm is indeed symmetric or not. +func RegisterKeyEncryptionAlgorithmWithOptions(v KeyEncryptionAlgorithm, options ...RegisterAlgorithmOption) { + var symmetric bool + //nolint:forcetypeassert + for _, option := range options { + switch option.Ident() { + case identSymmetricAlgorithm{}: + symmetric = option.Value().(bool) + } + } muKeyEncryptionAlgorithms.Lock() defer muKeyEncryptionAlgorithms.Unlock() if _, ok := allKeyEncryptionAlgorithms[v]; !ok { allKeyEncryptionAlgorithms[v] = struct{}{} + if symmetric { + symmetricKeyEncryptionAlgorithms[v] = struct{}{} + } rebuildKeyEncryptionAlgorithm() } } @@ -82,6 +114,9 @@ func UnregisterKeyEncryptionAlgorithm(v KeyEncryptionAlgorithm) { defer muKeyEncryptionAlgorithms.Unlock() if _, ok := allKeyEncryptionAlgorithms[v]; ok { delete(allKeyEncryptionAlgorithms, v) + if _, ok := symmetricKeyEncryptionAlgorithms[v]; ok { + delete(symmetricKeyEncryptionAlgorithms, v) + } rebuildKeyEncryptionAlgorithm() } } @@ -134,11 +169,8 @@ func (v KeyEncryptionAlgorithm) String() string { return string(v) } -// IsSymmetric returns true if the algorithm is a symmetric type +// IsSymmetric returns true if the algorithm is a symmetric type. func (v KeyEncryptionAlgorithm) IsSymmetric() bool { - switch v { - case A128GCMKW, A128KW, A192GCMKW, A192KW, A256GCMKW, A256KW, DIRECT, PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW: - return true - } - return false + _, ok := symmetricKeyEncryptionAlgorithms[v] + return ok } diff --git a/jwa/key_encryption_gen_test.go b/jwa/key_encryption_gen_test.go index 0db993b87..ccb49d9e0 100644 --- a/jwa/key_encryption_gen_test.go +++ b/jwa/key_encryption_gen_test.go @@ -803,3 +803,174 @@ func TestKeyEncryptionAlgorithm(t *testing.T) { } }) } + +// Note: this test can NOT be run in parallel as it uses options with global effect. +func TestKeyEncryptionAlgorithmCustomAlgorithm(t *testing.T) { + // These subtests can NOT be run in parallel as options with global effect change. + customAlgorithm := jwa.KeyEncryptionAlgorithm("custom-algorithm") + // Unregister the custom algorithm, in case tests fail. + t.Cleanup(func() { + jwa.UnregisterKeyEncryptionAlgorithm(customAlgorithm) + }) + t.Run(`with custom algorithm registered`, func(t *testing.T) { + jwa.RegisterKeyEncryptionAlgorithm(customAlgorithm) + t.Run(`accept variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + if !assert.NoError(t, dst.Accept(customAlgorithm), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + if !assert.NoError(t, dst.Accept(`custom-algorithm`), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + if !assert.NoError(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`check symmetric`, func(t *testing.T) { + t.Parallel() + assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`) + }) + }) + t.Run(`with custom algorithm deregistered`, func(t *testing.T) { + jwa.UnregisterKeyEncryptionAlgorithm(customAlgorithm) + t.Run(`reject variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + assert.Error(t, dst.Accept(customAlgorithm), `accept failed`) + }) + t.Run(`reject the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + assert.Error(t, dst.Accept(`custom-algorithm`), `accept failed`) + }) + t.Run(`reject fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + assert.Error(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept failed`) + }) + t.Run(`check symmetric`, func(t *testing.T) { + t.Parallel() + assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`) + }) + }) + + t.Run(`with custom algorithm registered with WithSymmetricAlgorithm(false)`, func(t *testing.T) { + jwa.RegisterKeyEncryptionAlgorithmWithOptions(customAlgorithm, jwa.WithSymmetricAlgorithm(false)) + t.Run(`accept variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + if !assert.NoError(t, dst.Accept(customAlgorithm), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + if !assert.NoError(t, dst.Accept(`custom-algorithm`), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + if !assert.NoError(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`check symmetric`, func(t *testing.T) { + t.Parallel() + assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`) + }) + }) + t.Run(`with custom algorithm deregistered (was WithSymmetricAlgorithm(false))`, func(t *testing.T) { + jwa.UnregisterKeyEncryptionAlgorithm(customAlgorithm) + t.Run(`reject variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + assert.Error(t, dst.Accept(customAlgorithm), `accept failed`) + }) + t.Run(`reject the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + assert.Error(t, dst.Accept(`custom-algorithm`), `accept failed`) + }) + t.Run(`reject fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + assert.Error(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept failed`) + }) + t.Run(`check symmetric`, func(t *testing.T) { + t.Parallel() + assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`) + }) + }) + + t.Run(`with custom algorithm registered with WithSymmetricAlgorithm(true)`, func(t *testing.T) { + jwa.RegisterKeyEncryptionAlgorithmWithOptions(customAlgorithm, jwa.WithSymmetricAlgorithm(true)) + t.Run(`accept variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + if !assert.NoError(t, dst.Accept(customAlgorithm), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + if !assert.NoError(t, dst.Accept(`custom-algorithm`), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + if !assert.NoError(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`check symmetric`, func(t *testing.T) { + t.Parallel() + assert.True(t, customAlgorithm.IsSymmetric(), `custom algorithm should be symmetric`) + }) + }) + t.Run(`with custom algorithm deregistered (was WithSymmetricAlgorithm(true))`, func(t *testing.T) { + jwa.UnregisterKeyEncryptionAlgorithm(customAlgorithm) + t.Run(`reject variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + assert.Error(t, dst.Accept(customAlgorithm), `accept failed`) + }) + t.Run(`reject the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + assert.Error(t, dst.Accept(`custom-algorithm`), `accept failed`) + }) + t.Run(`reject fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyEncryptionAlgorithm + assert.Error(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept failed`) + }) + t.Run(`check symmetric`, func(t *testing.T) { + t.Parallel() + assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`) + }) + }) +} diff --git a/jwa/key_type_gen_test.go b/jwa/key_type_gen_test.go index 56df50642..e8cbca749 100644 --- a/jwa/key_type_gen_test.go +++ b/jwa/key_type_gen_test.go @@ -195,3 +195,58 @@ func TestKeyType(t *testing.T) { } }) } + +// Note: this test can NOT be run in parallel as it uses options with global effect. +func TestKeyTypeCustomAlgorithm(t *testing.T) { + // These subtests can NOT be run in parallel as options with global effect change. + customAlgorithm := jwa.KeyType("custom-algorithm") + // Unregister the custom algorithm, in case tests fail. + t.Cleanup(func() { + jwa.UnregisterKeyType(customAlgorithm) + }) + t.Run(`with custom algorithm registered`, func(t *testing.T) { + jwa.RegisterKeyType(customAlgorithm) + t.Run(`accept variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyType + if !assert.NoError(t, dst.Accept(customAlgorithm), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyType + if !assert.NoError(t, dst.Accept(`custom-algorithm`), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyType + if !assert.NoError(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + }) + t.Run(`with custom algorithm deregistered`, func(t *testing.T) { + jwa.UnregisterKeyType(customAlgorithm) + t.Run(`reject variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyType + assert.Error(t, dst.Accept(customAlgorithm), `accept failed`) + }) + t.Run(`reject the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyType + assert.Error(t, dst.Accept(`custom-algorithm`), `accept failed`) + }) + t.Run(`reject fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.KeyType + assert.Error(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept failed`) + }) + }) +} diff --git a/jwa/options.yaml b/jwa/options.yaml new file mode 100644 index 000000000..1b5596abf --- /dev/null +++ b/jwa/options.yaml @@ -0,0 +1,15 @@ +package_name: jwa +output: jwa/options_gen.go +interfaces: + - name: RegisterAlgorithmOption + comment: | + RegisterAlgorithmOption describes options that can be passed to the algorithm registering + functions that support options such as RegisterKeyEncryptionAlgorithmWithOptions. +options: + - ident: SymmetricAlgorithm + interface: RegisterAlgorithmOption + argument_type: bool + comment: | + WithSymmetricAlgorithm lets the library know whether the algorithm is symmetric. This affects + the response of the `IsSymmetric` method of the algorithm. If the algorithms does not support + this method, using this option will result in an error. diff --git a/jwa/options_gen.go b/jwa/options_gen.go new file mode 100644 index 000000000..3a4122f06 --- /dev/null +++ b/jwa/options_gen.go @@ -0,0 +1,33 @@ +// Code generated by tools/cmd/genoptions/main.go. DO NOT EDIT. + +package jwa + +import "github.com/lestrrat-go/option" + +type Option = option.Interface + +// RegisterAlgorithmOption describes options that can be passed to the algorithm registering +// functions that support options such as RegisterKeyEncryptionAlgorithmWithOptions. +type RegisterAlgorithmOption interface { + Option + registerAlgorithmOption() +} + +type registerAlgorithmOption struct { + Option +} + +func (*registerAlgorithmOption) registerAlgorithmOption() {} + +type identSymmetricAlgorithm struct{} + +func (identSymmetricAlgorithm) String() string { + return "WithSymmetricAlgorithm" +} + +// WithSymmetricAlgorithm lets the library know whether the algorithm is symmetric. This affects +// the response of the `IsSymmetric` method of the algorithm. If the algorithms does not support +// this method, using this option will result in an error. +func WithSymmetricAlgorithm(v bool) RegisterAlgorithmOption { + return ®isterAlgorithmOption{option.New(identSymmetricAlgorithm{}, v)} +} diff --git a/jwa/options_gen_test.go b/jwa/options_gen_test.go new file mode 100644 index 000000000..fefe5aec4 --- /dev/null +++ b/jwa/options_gen_test.go @@ -0,0 +1,13 @@ +// Code generated by tools/cmd/genoptions/main.go. DO NOT EDIT. + +package jwa + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOptionIdent(t *testing.T) { + require.Equal(t, "WithSymmetricAlgorithm", identSymmetricAlgorithm{}.String()) +} diff --git a/jwa/signature_gen.go b/jwa/signature_gen.go index eaa2f8662..5c08bade9 100644 --- a/jwa/signature_gen.go +++ b/jwa/signature_gen.go @@ -33,6 +33,7 @@ const ( var muSignatureAlgorithms sync.RWMutex var allSignatureAlgorithms map[SignatureAlgorithm]struct{} var listSignatureAlgorithm []SignatureAlgorithm +var symmetricSignatureAlgorithms map[SignatureAlgorithm]struct{} func init() { muSignatureAlgorithms.Lock() @@ -53,16 +54,40 @@ func init() { allSignatureAlgorithms[RS256] = struct{}{} allSignatureAlgorithms[RS384] = struct{}{} allSignatureAlgorithms[RS512] = struct{}{} + symmetricSignatureAlgorithms = make(map[SignatureAlgorithm]struct{}) + symmetricSignatureAlgorithms[HS256] = struct{}{} + symmetricSignatureAlgorithms[HS384] = struct{}{} + symmetricSignatureAlgorithms[HS512] = struct{}{} rebuildSignatureAlgorithm() } // RegisterSignatureAlgorithm registers a new SignatureAlgorithm so that the jwx can properly handle the new value. // Duplicates will silently be ignored func RegisterSignatureAlgorithm(v SignatureAlgorithm) { + RegisterSignatureAlgorithmWithOptions(v) +} + +// RegisterSignatureAlgorithmWithOptions is the same as RegisterSignatureAlgorithm when used without options, +// but allows its behavior to change based on the provided options. +// This is an experimental AND stopgap function which will most likely be merged in RegisterSignatureAlgorithm, and subsequently removed in the future. As such it should not be considered part of the stable API -- it is still subject to change. +// +// You can pass `WithSymmetricAlgorithm(true)` to let the library know that it's a symmetric algorithm. This library makes no attempt to verify if the algorithm is indeed symmetric or not. +func RegisterSignatureAlgorithmWithOptions(v SignatureAlgorithm, options ...RegisterAlgorithmOption) { + var symmetric bool + //nolint:forcetypeassert + for _, option := range options { + switch option.Ident() { + case identSymmetricAlgorithm{}: + symmetric = option.Value().(bool) + } + } muSignatureAlgorithms.Lock() defer muSignatureAlgorithms.Unlock() if _, ok := allSignatureAlgorithms[v]; !ok { allSignatureAlgorithms[v] = struct{}{} + if symmetric { + symmetricSignatureAlgorithms[v] = struct{}{} + } rebuildSignatureAlgorithm() } } @@ -74,6 +99,9 @@ func UnregisterSignatureAlgorithm(v SignatureAlgorithm) { defer muSignatureAlgorithms.Unlock() if _, ok := allSignatureAlgorithms[v]; ok { delete(allSignatureAlgorithms, v) + if _, ok := symmetricSignatureAlgorithms[v]; ok { + delete(symmetricSignatureAlgorithms, v) + } rebuildSignatureAlgorithm() } } @@ -125,3 +153,10 @@ func (v *SignatureAlgorithm) Accept(value interface{}) error { func (v SignatureAlgorithm) String() string { return string(v) } + +// IsSymmetric returns true if the algorithm is a symmetric type. +// Keep in mind that the NoSignature algorithm is neither a symmetric nor an asymmetric algorithm. +func (v SignatureAlgorithm) IsSymmetric() bool { + _, ok := symmetricSignatureAlgorithms[v] + return ok +} diff --git a/jwa/signature_gen_test.go b/jwa/signature_gen_test.go index 647111fd6..09cdf4f71 100644 --- a/jwa/signature_gen_test.go +++ b/jwa/signature_gen_test.go @@ -565,6 +565,54 @@ func TestSignatureAlgorithm(t *testing.T) { return } }) + t.Run(`check symmetric values`, func(t *testing.T) { + t.Parallel() + t.Run(`ES256`, func(t *testing.T) { + assert.False(t, jwa.ES256.IsSymmetric(), `jwa.ES256 should NOT be symmetric`) + }) + t.Run(`ES256K`, func(t *testing.T) { + assert.False(t, jwa.ES256K.IsSymmetric(), `jwa.ES256K should NOT be symmetric`) + }) + t.Run(`ES384`, func(t *testing.T) { + assert.False(t, jwa.ES384.IsSymmetric(), `jwa.ES384 should NOT be symmetric`) + }) + t.Run(`ES512`, func(t *testing.T) { + assert.False(t, jwa.ES512.IsSymmetric(), `jwa.ES512 should NOT be symmetric`) + }) + t.Run(`EdDSA`, func(t *testing.T) { + assert.False(t, jwa.EdDSA.IsSymmetric(), `jwa.EdDSA should NOT be symmetric`) + }) + t.Run(`HS256`, func(t *testing.T) { + assert.True(t, jwa.HS256.IsSymmetric(), `jwa.HS256 should be symmetric`) + }) + t.Run(`HS384`, func(t *testing.T) { + assert.True(t, jwa.HS384.IsSymmetric(), `jwa.HS384 should be symmetric`) + }) + t.Run(`HS512`, func(t *testing.T) { + assert.True(t, jwa.HS512.IsSymmetric(), `jwa.HS512 should be symmetric`) + }) + t.Run(`NoSignature`, func(t *testing.T) { + assert.False(t, jwa.NoSignature.IsSymmetric(), `jwa.NoSignature should NOT be symmetric`) + }) + t.Run(`PS256`, func(t *testing.T) { + assert.False(t, jwa.PS256.IsSymmetric(), `jwa.PS256 should NOT be symmetric`) + }) + t.Run(`PS384`, func(t *testing.T) { + assert.False(t, jwa.PS384.IsSymmetric(), `jwa.PS384 should NOT be symmetric`) + }) + t.Run(`PS512`, func(t *testing.T) { + assert.False(t, jwa.PS512.IsSymmetric(), `jwa.PS512 should NOT be symmetric`) + }) + t.Run(`RS256`, func(t *testing.T) { + assert.False(t, jwa.RS256.IsSymmetric(), `jwa.RS256 should NOT be symmetric`) + }) + t.Run(`RS384`, func(t *testing.T) { + assert.False(t, jwa.RS384.IsSymmetric(), `jwa.RS384 should NOT be symmetric`) + }) + t.Run(`RS512`, func(t *testing.T) { + assert.False(t, jwa.RS512.IsSymmetric(), `jwa.RS512 should NOT be symmetric`) + }) + }) t.Run(`check list of elements`, func(t *testing.T) { t.Parallel() var expected = map[jwa.SignatureAlgorithm]struct{}{ @@ -595,3 +643,174 @@ func TestSignatureAlgorithm(t *testing.T) { } }) } + +// Note: this test can NOT be run in parallel as it uses options with global effect. +func TestSignatureAlgorithmCustomAlgorithm(t *testing.T) { + // These subtests can NOT be run in parallel as options with global effect change. + customAlgorithm := jwa.SignatureAlgorithm("custom-algorithm") + // Unregister the custom algorithm, in case tests fail. + t.Cleanup(func() { + jwa.UnregisterSignatureAlgorithm(customAlgorithm) + }) + t.Run(`with custom algorithm registered`, func(t *testing.T) { + jwa.RegisterSignatureAlgorithm(customAlgorithm) + t.Run(`accept variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + if !assert.NoError(t, dst.Accept(customAlgorithm), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + if !assert.NoError(t, dst.Accept(`custom-algorithm`), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + if !assert.NoError(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`check symmetric`, func(t *testing.T) { + t.Parallel() + assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`) + }) + }) + t.Run(`with custom algorithm deregistered`, func(t *testing.T) { + jwa.UnregisterSignatureAlgorithm(customAlgorithm) + t.Run(`reject variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + assert.Error(t, dst.Accept(customAlgorithm), `accept failed`) + }) + t.Run(`reject the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + assert.Error(t, dst.Accept(`custom-algorithm`), `accept failed`) + }) + t.Run(`reject fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + assert.Error(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept failed`) + }) + t.Run(`check symmetric`, func(t *testing.T) { + t.Parallel() + assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`) + }) + }) + + t.Run(`with custom algorithm registered with WithSymmetricAlgorithm(false)`, func(t *testing.T) { + jwa.RegisterSignatureAlgorithmWithOptions(customAlgorithm, jwa.WithSymmetricAlgorithm(false)) + t.Run(`accept variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + if !assert.NoError(t, dst.Accept(customAlgorithm), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + if !assert.NoError(t, dst.Accept(`custom-algorithm`), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + if !assert.NoError(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`check symmetric`, func(t *testing.T) { + t.Parallel() + assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`) + }) + }) + t.Run(`with custom algorithm deregistered (was WithSymmetricAlgorithm(false))`, func(t *testing.T) { + jwa.UnregisterSignatureAlgorithm(customAlgorithm) + t.Run(`reject variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + assert.Error(t, dst.Accept(customAlgorithm), `accept failed`) + }) + t.Run(`reject the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + assert.Error(t, dst.Accept(`custom-algorithm`), `accept failed`) + }) + t.Run(`reject fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + assert.Error(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept failed`) + }) + t.Run(`check symmetric`, func(t *testing.T) { + t.Parallel() + assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`) + }) + }) + + t.Run(`with custom algorithm registered with WithSymmetricAlgorithm(true)`, func(t *testing.T) { + jwa.RegisterSignatureAlgorithmWithOptions(customAlgorithm, jwa.WithSymmetricAlgorithm(true)) + t.Run(`accept variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + if !assert.NoError(t, dst.Accept(customAlgorithm), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + if !assert.NoError(t, dst.Accept(`custom-algorithm`), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`accept fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + if !assert.NoError(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept is successful`) { + return + } + assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`) + }) + t.Run(`check symmetric`, func(t *testing.T) { + t.Parallel() + assert.True(t, customAlgorithm.IsSymmetric(), `custom algorithm should be symmetric`) + }) + }) + t.Run(`with custom algorithm deregistered (was WithSymmetricAlgorithm(true))`, func(t *testing.T) { + jwa.UnregisterSignatureAlgorithm(customAlgorithm) + t.Run(`reject variable used to register custom algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + assert.Error(t, dst.Accept(customAlgorithm), `accept failed`) + }) + t.Run(`reject the string custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + assert.Error(t, dst.Accept(`custom-algorithm`), `accept failed`) + }) + t.Run(`reject fmt.Stringer for custom-algorithm`, func(t *testing.T) { + t.Parallel() + var dst jwa.SignatureAlgorithm + assert.Error(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept failed`) + }) + t.Run(`check symmetric`, func(t *testing.T) { + t.Parallel() + assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`) + }) + }) +} diff --git a/tools/cmd/genjwa/main.go b/tools/cmd/genjwa/main.go index 18169f00b..cd42a9b52 100644 --- a/tools/cmd/genjwa/main.go +++ b/tools/cmd/genjwa/main.go @@ -149,9 +149,10 @@ func _main() error { }, }, { - name: `SignatureAlgorithm`, - comment: `SignatureAlgorithm represents the various signature algorithms as described in https://tools.ietf.org/html/rfc7518#section-3.1`, - filename: `signature_gen.go`, + name: `SignatureAlgorithm`, + comment: `SignatureAlgorithm represents the various signature algorithms as described in https://tools.ietf.org/html/rfc7518#section-3.1`, + filename: `signature_gen.go`, + symmetric: true, elements: []element{ { name: `NoSignature`, @@ -161,16 +162,19 @@ func _main() error { name: `HS256`, value: "HS256", comment: `HMAC using SHA-256`, + sym: true, }, { name: `HS384`, value: `HS384`, comment: `HMAC using SHA-384`, + sym: true, }, { name: `HS512`, value: "HS512", comment: `HMAC using SHA-512`, + sym: true, }, { name: `RS256`, @@ -230,9 +234,10 @@ func _main() error { }, }, { - name: `KeyEncryptionAlgorithm`, - comment: `KeyEncryptionAlgorithm represents the various encryption algorithms as described in https://tools.ietf.org/html/rfc7518#section-4.1`, - filename: `key_encryption_gen.go`, + name: `KeyEncryptionAlgorithm`, + comment: `KeyEncryptionAlgorithm represents the various encryption algorithms as described in https://tools.ietf.org/html/rfc7518#section-4.1`, + filename: `key_encryption_gen.go`, + symmetric: true, elements: []element{ { name: `RSA1_5`, @@ -263,21 +268,25 @@ func _main() error { name: `A128KW`, value: "A128KW", comment: `AES key wrap (128)`, + sym: true, }, { name: `A192KW`, value: "A192KW", comment: `AES key wrap (192)`, + sym: true, }, { name: `A256KW`, value: "A256KW", comment: `AES key wrap (256)`, + sym: true, }, { name: `DIRECT`, value: "dir", comment: `Direct encryption`, + sym: true, }, { name: `ECDH_ES`, @@ -303,31 +312,37 @@ func _main() error { name: `A128GCMKW`, value: "A128GCMKW", comment: `AES-GCM key wrap (128)`, + sym: true, }, { name: `A192GCMKW`, value: "A192GCMKW", comment: `AES-GCM key wrap (192)`, + sym: true, }, { name: `A256GCMKW`, value: "A256GCMKW", comment: `AES-GCM key wrap (256)`, + sym: true, }, { name: `PBES2_HS256_A128KW`, value: "PBES2-HS256+A128KW", comment: `PBES2 + HMAC-SHA256 + AES key wrap (128)`, + sym: true, }, { name: `PBES2_HS384_A192KW`, value: "PBES2-HS384+A192KW", comment: `PBES2 + HMAC-SHA384 + AES key wrap (192)`, + sym: true, }, { name: `PBES2_HS512_A256KW`, value: "PBES2-HS512+A256KW", comment: `PBES2 + HMAC-SHA512 + AES key wrap (256)`, + sym: true, }, }, }, @@ -353,10 +368,11 @@ func _main() error { } type typ struct { - name string - comment string - filename string - elements []element + name string + comment string + filename string + elements []element + symmetric bool } type element struct { @@ -364,20 +380,7 @@ type element struct { value string comment string invalid bool -} - -var isSymmetricKeyEncryption = map[string]struct{}{ - `A128KW`: {}, - `A192KW`: {}, - `A256KW`: {}, - `DIRECT`: {}, - `A128GCMKW`: {}, - `A192GCMKW`: {}, - `A256GCMKW`: {}, - - `PBES2_HS256_A128KW`: {}, - `PBES2_HS384_A192KW`: {}, - `PBES2_HS512_A256KW`: {}, + sym bool } func (t typ) Generate() error { @@ -391,6 +394,9 @@ func (t typ) Generate() error { o.LL("import (") pkgs := []string{ "fmt", + "sort", + "sync", + "strings", } for _, pkg := range pkgs { o.L("%s", strconv.Quote(pkg)) @@ -410,12 +416,15 @@ func (t typ) Generate() error { } o.L(")") // end const - // Register%s and related tools are provided so users can register their own types. + // Register and related tools are provided so users can register their own types. // This triggers some re-building of data structures that are otherwise // reused for efficiency o.LL("var mu%[1]ss sync.RWMutex", t.name) o.L("var all%[1]ss map[%[1]s]struct{}", t.name) o.L("var list%[1]s []%[1]s", t.name) + if t.symmetric { + o.L("var symmetric%[1]ss map[%[1]s]struct{}", t.name) + } o.LL("func init() {") o.L("mu%[1]ss.Lock()", t.name) @@ -426,19 +435,60 @@ func (t typ) Generate() error { o.L("all%[1]ss[%[2]s] = struct{}{}", t.name, e.name) } } + if t.symmetric { + o.L("symmetric%[1]ss = make(map[%[1]s]struct{})", t.name) + for _, e := range t.elements { + if !e.invalid && e.sym { + o.L("symmetric%[1]ss[%[2]s] = struct{}{}", t.name, e.name) + } + } + } o.L("rebuild%[1]s()", t.name) o.L("}") - o.LL("// Register%[1]s registers a new %[1]s so that the jwx can properly handle the new value.", t.name) - o.L("// Duplicates will silently be ignored") - o.L("func Register%[1]s(v %[1]s) {", t.name) - o.L("mu%[1]ss.Lock()", t.name) - o.L("defer mu%[1]ss.Unlock()", t.name) - o.L("if _, ok := all%[1]ss[v]; !ok {", t.name) - o.L("all%[1]ss[v] = struct{}{}", t.name) - o.L("rebuild%[1]s()", t.name) - o.L("}") - o.L("}") + if !t.symmetric { + o.LL("// Register%[1]s registers a new %[1]s so that the jwx can properly handle the new value.", t.name) + o.L("// Duplicates will silently be ignored") + o.L("func Register%[1]s(v %[1]s) {", t.name) + o.L("mu%[1]ss.Lock()", t.name) + o.L("defer mu%[1]ss.Unlock()", t.name) + o.L("if _, ok := all%[1]ss[v]; !ok {", t.name) + o.L("all%[1]ss[v] = struct{}{}", t.name) + o.L("rebuild%[1]s()", t.name) + o.L("}") + o.L("}") + } else { + o.LL("// Register%[1]s registers a new %[1]s so that the jwx can properly handle the new value.", t.name) + o.L("// Duplicates will silently be ignored") + o.L("func Register%[1]s(v %[1]s) {", t.name) + o.L("Register%[1]sWithOptions(v)", t.name) + o.L("}") + + o.LL("// Register%[1]sWithOptions is the same as Register%[1]s when used without options,", t.name) + o.L("// but allows its behavior to change based on the provided options.") + o.L("// This is an experimental AND stopgap function which will most likely be merged in Register%[1]s, and subsequently removed in the future. As such it should not be considered part of the stable API -- it is still subject to change.", t.name) + o.L("//") + o.L("// You can pass `WithSymmetricAlgorithm(true)` to let the library know that it's a symmetric algorithm. This library makes no attempt to verify if the algorithm is indeed symmetric or not.") + o.L("func Register%[1]sWithOptions(v %[1]s, options ...RegisterAlgorithmOption) {", t.name) + o.L("var symmetric bool") + o.L("//nolint:forcetypeassert") + o.L("for _, option := range options {") + o.L("switch option.Ident() {") + o.L("case identSymmetricAlgorithm{}:") + o.L("symmetric = option.Value().(bool)") + o.L("}") + o.L("}") + o.L("mu%[1]ss.Lock()", t.name) + o.L("defer mu%[1]ss.Unlock()", t.name) + o.L("if _, ok := all%[1]ss[v]; !ok {", t.name) + o.L("all%[1]ss[v] = struct{}{}", t.name) + o.L("if symmetric {") + o.L("symmetric%[1]ss[v] = struct{}{}", t.name) + o.L("}") + o.L("rebuild%[1]s()", t.name) + o.L("}") + o.L("}") + } o.LL("// Unregister%[1]s unregisters a %[1]s from its known database.", t.name) o.L("// Non-existentn entries will silently be ignored") @@ -447,6 +497,11 @@ func (t typ) Generate() error { o.L("defer mu%[1]ss.Unlock()", t.name) o.L("if _, ok := all%[1]ss[v]; ok {", t.name) o.L("delete(all%[1]ss, v)", t.name) + if t.symmetric { + o.L("if _, ok := symmetric%[1]ss[v]; ok {", t.name) + o.L("delete(symmetric%[1]ss, v)", t.name) + o.L("}") + } o.L("rebuild%[1]s()", t.name) o.L("}") o.L("}") @@ -500,27 +555,14 @@ func (t typ) Generate() error { o.L("return string(v)") o.L("}") - if t.name == "KeyEncryptionAlgorithm" { - o.LL("// IsSymmetric returns true if the algorithm is a symmetric type") - o.L("func (v %s) IsSymmetric() bool {", t.name) - o.L("switch v {") - o.L("case ") - var count int - for _, e := range t.elements { - if _, ok := isSymmetricKeyEncryption[e.name]; !ok { - continue - } - if count == 0 { - o.R("%s", e.name) - } else { - o.R(",%s", e.name) - } - count++ + if t.symmetric { + o.LL("// IsSymmetric returns true if the algorithm is a symmetric type.") + if t.name == "SignatureAlgorithm" { + o.L("// Keep in mind that the NoSignature algorithm is neither a symmetric nor an asymmetric algorithm.") } - o.R(":") - o.L("return true") - o.L("}") - o.L("return false") + o.L("func (v %s) IsSymmetric() bool {", t.name) + o.L("_, ok := symmetric%[1]ss[v]", t.name) + o.L("return ok") o.L("}") } @@ -631,12 +673,12 @@ func (t typ) GenerateTest() error { o.L("}") o.L("})") - if t.name == "KeyEncryptionAlgorithm" { + if t.symmetric { o.L("t.Run(`check symmetric values`, func(t *testing.T) {") o.L("t.Parallel()") for _, e := range t.elements { o.L("t.Run(`%s`, func(t *testing.T) {", e.name) - if _, ok := isSymmetricKeyEncryption[e.name]; ok { + if e.sym { o.L("assert.True(t, jwa.%[1]s.IsSymmetric(), `jwa.%[1]s should be symmetric`)", e.name) } else { o.L("assert.False(t, jwa.%[1]s.IsSymmetric(), `jwa.%[1]s should NOT be symmetric`)", e.name) @@ -672,7 +714,135 @@ func (t typ) GenerateTest() error { o.L("return") o.L("}") o.L("})") + o.L("}") + + o.LL("// Note: this test can NOT be run in parallel as it uses options with global effect.") + o.L("func Test%sCustomAlgorithm(t *testing.T) {", t.name) + o.L("// These subtests can NOT be run in parallel as options with global effect change.") + o.L(`customAlgorithm := jwa.%[1]s("custom-algorithm")`, t.name) + o.L("// Unregister the custom algorithm, in case tests fail.") + o.L("t.Cleanup(func() {") + o.L("jwa.Unregister%[1]s(customAlgorithm)", t.name) + o.L("})") + o.L("t.Run(`with custom algorithm registered`, func(t *testing.T) {") + o.L("jwa.Register%[1]s(customAlgorithm)", t.name) + o.L("t.Run(`accept variable used to register custom algorithm`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("var dst jwa.%[1]s", t.name) + o.L("if !assert.NoError(t, dst.Accept(customAlgorithm), `accept is successful`) {") + o.L("return") + o.L("}") + o.L("assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`)") + o.L("})") + o.L("t.Run(`accept the string custom-algorithm`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("var dst jwa.%[1]s", t.name) + o.L("if !assert.NoError(t, dst.Accept(`custom-algorithm`), `accept is successful`) {") + o.L("return") + o.L("}") + o.L("assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`)") + o.L("})") + o.L("t.Run(`accept fmt.Stringer for custom-algorithm`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("var dst jwa.%[1]s", t.name) + o.L("if !assert.NoError(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept is successful`) {") + o.L("return") + o.L("}") + o.L("assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`)") + o.L("})") + if t.symmetric { + o.L("t.Run(`check symmetric`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`)") + o.L("})") + } + o.L("})") + o.L("t.Run(`with custom algorithm deregistered`, func(t *testing.T) {") + o.L("jwa.Unregister%[1]s(customAlgorithm)", t.name) + o.L("t.Run(`reject variable used to register custom algorithm`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("var dst jwa.%[1]s", t.name) + o.L("assert.Error(t, dst.Accept(customAlgorithm), `accept failed`)") + o.L("})") + o.L("t.Run(`reject the string custom-algorithm`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("var dst jwa.%[1]s", t.name) + o.L("assert.Error(t, dst.Accept(`custom-algorithm`), `accept failed`)") + o.L("})") + o.L("t.Run(`reject fmt.Stringer for custom-algorithm`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("var dst jwa.%[1]s", t.name) + o.L("assert.Error(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept failed`)") + o.L("})") + if t.symmetric { + o.L("t.Run(`check symmetric`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`)") + o.L("})") + } + o.L("})") + if t.symmetric { + for _, value := range []bool{false, true} { + o.LL("t.Run(`with custom algorithm registered with WithSymmetricAlgorithm(%t)`, func(t *testing.T) {", value) + o.L("jwa.Register%[1]sWithOptions(customAlgorithm, jwa.WithSymmetricAlgorithm(%[2]t))", t.name, value) + o.L("t.Run(`accept variable used to register custom algorithm`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("var dst jwa.%[1]s", t.name) + o.L("if !assert.NoError(t, dst.Accept(customAlgorithm), `accept is successful`) {") + o.L("return") + o.L("}") + o.L("assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`)") + o.L("})") + o.L("t.Run(`accept the string custom-algorithm`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("var dst jwa.%[1]s", t.name) + o.L("if !assert.NoError(t, dst.Accept(`custom-algorithm`), `accept is successful`) {") + o.L("return") + o.L("}") + o.L("assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`)") + o.L("})") + o.L("t.Run(`accept fmt.Stringer for custom-algorithm`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("var dst jwa.%[1]s", t.name) + o.L("if !assert.NoError(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept is successful`) {") + o.L("return") + o.L("}") + o.L("assert.Equal(t, customAlgorithm, dst, `accepted value should be equal to variable`)") + o.L("})") + o.L("t.Run(`check symmetric`, func(t *testing.T) {") + o.L("t.Parallel()") + if value { + o.L("assert.True(t, customAlgorithm.IsSymmetric(), `custom algorithm should be symmetric`)") + } else { + o.L("assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`)") + } + o.L("})") + o.L("})") + o.L("t.Run(`with custom algorithm deregistered (was WithSymmetricAlgorithm(%t))`, func(t *testing.T) {", value) + o.L("jwa.Unregister%[1]s(customAlgorithm)", t.name) + o.L("t.Run(`reject variable used to register custom algorithm`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("var dst jwa.%[1]s", t.name) + o.L("assert.Error(t, dst.Accept(customAlgorithm), `accept failed`)") + o.L("})") + o.L("t.Run(`reject the string custom-algorithm`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("var dst jwa.%[1]s", t.name) + o.L("assert.Error(t, dst.Accept(`custom-algorithm`), `accept failed`)") + o.L("})") + o.L("t.Run(`reject fmt.Stringer for custom-algorithm`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("var dst jwa.%[1]s", t.name) + o.L("assert.Error(t, dst.Accept(stringer{src: `custom-algorithm`}), `accept failed`)") + o.L("})") + o.L("t.Run(`check symmetric`, func(t *testing.T) {") + o.L("t.Parallel()") + o.L("assert.False(t, customAlgorithm.IsSymmetric(), `custom algorithm should NOT be symmetric`)") + o.L("})") + o.L("})") + } + } o.L("}") filename := strings.Replace(t.filename, "_gen.go", "_gen_test.go", 1)