Skip to content

Commit

Permalink
Convert jwa constants to function calls returning objects (#1203)
Browse files Browse the repository at this point in the history
* Convert jwa constants to function calls returning objects

* Fix jwx_test.go

* Fix tests

* more test fixes

* convert secp2561k

* check for return value of ok

* fix more algorithm handling

* get on with the times and use go fmt instead of gofmt

* the lookup for was builtin algorithms

* typo

* Fix for secp256k1 related code

* Fix examples

* Fix es256k test

* fix es256k test

* fix benchmark

* fix cmd

* Fix option generation

---------

Co-authored-by: Daisuke Maki <[email protected]>
  • Loading branch information
lestrrat and Daisuke Maki authored Oct 6, 2024
1 parent 4727853 commit edde47b
Show file tree
Hide file tree
Showing 109 changed files with 2,771 additions and 2,958 deletions.
10 changes: 10 additions & 0 deletions Changes-v3.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ These are changes that are incompatible with the v2.x.x version.
`Get(string, interface{}) error`, where the second argument should be a pointer
to the storage destination of the field.

## JWA

* All string constants have been renamed to equivalent functions that return a struct.
* By default, only known algorithm names are accepted. For example, in our JWK tests,
there are tests that deal with "ECMR" algorithm, but this will now fail by default.
If you want this algorithm to succeed parsing, you need to call `jwa.RegisterXXXX`
functions before using them
* Previously, unmarshaling unquoted strings used to work (e.g. `var s = "RS256"`),
but now they must conform to the JSON standard and be quoted (e.g. `var s = strconv.Quote("RS256")`)

## JWS

* Iterators have been completely removed.
Expand Down
2 changes: 1 addition & 1 deletion bench/performance/jwt_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

func BenchmarkJWT(b *testing.B) {
alg := jwa.RS256
alg := jwa.RS256()

key, err := jwxtest.GenerateRsaJwk()
if err != nil {
Expand Down
35 changes: 26 additions & 9 deletions cmd/jwx/jwe.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,26 @@ func makeJweEncryptCmd() *cli.Command {
}

var keyenc jwa.KeyEncryptionAlgorithm
if err := keyenc.Accept(c.String("key-encryption")); err != nil {
return fmt.Errorf(`invalid key encryption algorithm: %w`, err)
{
v, ok := jwa.LookupKeyEncryptionAlgorithm(c.String("key-encryption"))
if !ok {
return fmt.Errorf(`invalid key encryption algorithm %q`, c.String("key-encryption"))
}
keyenc = v
}

var cntenc jwa.ContentEncryptionAlgorithm
if err := cntenc.Accept(c.String("content-encryption")); err != nil {
return fmt.Errorf(`invalid content encryption algorithm: %w`, err)
{
v, ok := jwa.LookupContentEncryptionAlgorithm(c.String("content-encryption"))
if !ok {
return fmt.Errorf(`invalid content encryption algorithm %q`, c.String("content-encryption"))
}
cntenc = v
}

compress := jwa.NoCompress
compress := jwa.NoCompress()
if c.Bool("compress") {
compress = jwa.Deflate
compress = jwa.Deflate()
}

keyset, err := getKeyFile(c.String("key"), c.String("key-format"))
Expand Down Expand Up @@ -157,8 +165,13 @@ func makeJweDecryptCmd() *cli.Command {

if keyencalg := c.String("key-encryption"); keyencalg != "" {
var keyenc jwa.KeyEncryptionAlgorithm
if err := keyenc.Accept(c.String("key-encryption")); err != nil {
return fmt.Errorf(`invalid key encryption algorithm: %w`, err)
{
v, ok := jwa.LookupKeyEncryptionAlgorithm(keyencalg)
if !ok {
return fmt.Errorf(`invalid key encryption algorithm %q`, keyencalg)
}
keyenc = v

}

// if we have an explicit key encryption algorithm, we don't have to
Expand All @@ -170,7 +183,11 @@ func makeJweDecryptCmd() *cli.Command {
decrypted = v
} else {
v, err := jwe.Decrypt(buf, jwe.WithKeyProvider(jwe.KeyProviderFunc(func(_ context.Context, sink jwe.KeySink, r jwe.Recipient, _ *jwe.Message) error {
sink.Key(r.Headers().Algorithm(), key)
alg, ok := r.Headers().Algorithm()
if !ok {
return fmt.Errorf(`failed to determine key encryption algorithm`)
}
sink.Key(alg, key)
return nil
})))
if err != nil {
Expand Down
37 changes: 25 additions & 12 deletions cmd/jwx/jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,26 @@ func makeJwkGenerateCmd() *cli.Command {

cmd.Action = func(c *cli.Context) error {
var rawkey interface{}
switch typ := jwa.KeyType(c.String("type")); typ {
case jwa.RSA:
typ, ok := jwa.LookupKeyType(c.String("type"))
if !ok {
return fmt.Errorf(`invalid key type %s`, c.String("type"))
}

switch typ {
case jwa.RSA():
v, err := rsa.GenerateKey(rand.Reader, c.Int("keysize"))
if err != nil {
return fmt.Errorf(`failed to generate rsa private key: %w`, err)
}
rawkey = v
case jwa.EC:
case jwa.EC():
var crvalg jwa.EllipticCurveAlgorithm
if err := crvalg.Accept(c.String("curve")); err != nil {
return fmt.Errorf(`invalid elliptic curve name %s: %w`, c.String("curve"), err)
{
v, ok := jwa.LookupEllipticCurveAlgorithm(c.String("curve"))
if !ok {
return fmt.Errorf(`invalid elliptic curve name %q`, c.String("curve"))
}
crvalg = v
}

crv, err := ourecdsa.CurveFromAlgorithm(crvalg)
Expand All @@ -153,32 +162,36 @@ func makeJwkGenerateCmd() *cli.Command {
return fmt.Errorf(`failed to generate ECDSA private key: %w`, err)
}
rawkey = v
case jwa.OctetSeq:
case jwa.OctetSeq():
octets := make([]byte, c.Int("keysize"))
io.ReadFull(rand.Reader, octets)

rawkey = octets
case jwa.OKP:
case jwa.OKP():
var crvalg jwa.EllipticCurveAlgorithm
if err := crvalg.Accept(c.String("curve")); err != nil {
return fmt.Errorf(`invalid elliptic curve name: %w`, err)
{
v, ok := jwa.LookupEllipticCurveAlgorithm(c.String("curve"))
if !ok {
return fmt.Errorf(`invalid elliptic curve name %q`, c.String("curve"))
}
crvalg = v
}

switch crvalg {
case jwa.Ed25519:
case jwa.Ed25519():
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return fmt.Errorf(`failed to generate ed25519 private key: %w`, err)
}
rawkey = priv
case jwa.X25519:
case jwa.X25519():
priv, err := ecdh.X25519().GenerateKey(rand.Reader)
if err != nil {
return fmt.Errorf(`failed to generate x25519 private key: %w`, err)
}
rawkey = priv
default:
return fmt.Errorf(`invalid elliptic curve for OKP: %s (expected %s/%s)`, crvalg, jwa.Ed25519, jwa.X25519)
return fmt.Errorf(`invalid elliptic curve for OKP: %s (expected %s/%s)`, crvalg, jwa.Ed25519(), jwa.X25519())
}
default:
return fmt.Errorf(`invalid key type %s`, typ)
Expand Down
35 changes: 12 additions & 23 deletions cmd/jwx/jws.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,6 @@ func makeJwsParseCmd() *cli.Command {
buf, err := io.ReadAll(src)
if err != nil {
return fmt.Errorf(`failed to read data from source: %w`, err)
if err != nil {
return fmt.Errorf(`failed to read message: %w`, err)
}
}

msg, err := jws.Parse(buf)
Expand Down Expand Up @@ -153,9 +150,6 @@ func makeJwsVerifyCmd() *cli.Command {
buf, err := io.ReadAll(src)
if err != nil {
return fmt.Errorf(`failed to read data from source: %w`, err)
if err != nil {
return fmt.Errorf(`failed to verify message: %w`, err)
}
}

output, err := getOutput(c.String("output"))
Expand All @@ -172,13 +166,12 @@ func makeJwsVerifyCmd() *cli.Command {
}
} else {
var alg jwa.SignatureAlgorithm
givenalg := c.String("alg")
if givenalg == "" {
return fmt.Errorf(`option --alg must be given`)
}

if err := alg.Accept(givenalg); err != nil {
return fmt.Errorf(`invalid alg %s`, givenalg)
{
v, ok := jwa.LookupSignatureAlgorithm(c.String("alg"))
if !ok {
return fmt.Errorf(`invalid algorithm %s`, c.String("alg"))
}
alg = v
}

for i := 0; i < keyset.Len(); i++ {
Expand Down Expand Up @@ -243,19 +236,15 @@ func makeJwsSignCmd() *cli.Command {
buf, err := io.ReadAll(src)
if err != nil {
return fmt.Errorf(`failed to read data from source: %w`, err)
if err != nil {
return fmt.Errorf(`failed to sign message: %w`, err)
}
}

var alg jwa.SignatureAlgorithm
givenalg := c.String("alg")
if givenalg == "" {
return fmt.Errorf(`option --alg must be given`)
}

if err := alg.Accept(givenalg); err != nil {
return fmt.Errorf(`invalid alg %s`, givenalg)
{
v, ok := jwa.LookupSignatureAlgorithm(c.String("alg"))
if !ok {
return fmt.Errorf(`invalid algorithm %s`, c.String("alg"))
}
alg = v
}

// headers must go to WithKeySuboptions
Expand Down
4 changes: 2 additions & 2 deletions examples/jwe_decrypt_with_key_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import (

func ExampleJWE_VerifyWithKey() {
const payload = "Lorem ipsum"
encrypted, err := jwe.Encrypt([]byte(payload), jwe.WithKey(jwa.RSA_OAEP, jwkRSAPublicKey))
encrypted, err := jwe.Encrypt([]byte(payload), jwe.WithKey(jwa.RSA_OAEP(), jwkRSAPublicKey))
if err != nil {
fmt.Printf("failed to sign payload: %s\n", err)
return
}

decrypted, err := jwe.Decrypt(encrypted, jwe.WithKey(jwa.RSA_OAEP, jwkRSAPrivateKey))
decrypted, err := jwe.Decrypt(encrypted, jwe.WithKey(jwa.RSA_OAEP(), jwkRSAPrivateKey))
if err != nil {
fmt.Printf("failed to sign payload: %s\n", err)
return
Expand Down
4 changes: 2 additions & 2 deletions examples/jwe_decrypt_with_keyset_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func ExampleJWE_VerifyWithJWKSet() {
return
}
const payload = "Lorem ipsum"
encrypted, err := jwe.Encrypt([]byte(payload), jwe.WithKey(jwa.RSA_OAEP, privkey.PublicKey))
encrypted, err := jwe.Encrypt([]byte(payload), jwe.WithKey(jwa.RSA_OAEP(), privkey.PublicKey))
if err != nil {
fmt.Printf("failed to sign payload: %s\n", err)
return
Expand All @@ -32,7 +32,7 @@ func ExampleJWE_VerifyWithJWKSet() {
set.AddKey(k2)
// Add the real thing
k3, _ := jwk.Import(privkey)
k3.Set(jwk.AlgorithmKey, jwa.RSA_OAEP)
k3.Set(jwk.AlgorithmKey, jwa.RSA_OAEP())
set.AddKey(k3)

// Up to this point, you probably will replace with a simple jwk.Fetch()
Expand Down
4 changes: 2 additions & 2 deletions examples/jwe_encrypt_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ func ExampleJWE_Encrypt() {
}

const payload = `Lorem ipsum`
encrypted, err := jwe.Encrypt([]byte(payload), jwe.WithKey(jwa.RSA_OAEP, pubkey))
encrypted, err := jwe.Encrypt([]byte(payload), jwe.WithKey(jwa.RSA_OAEP(), pubkey))
if err != nil {
fmt.Printf("failed to encrypt payload: %s\n", err)
return
}

decrypted, err := jwe.Decrypt(encrypted, jwe.WithKey(jwa.RSA_OAEP, privkey))
decrypted, err := jwe.Decrypt(encrypted, jwe.WithKey(jwa.RSA_OAEP(), privkey))
if err != nil {
fmt.Printf("failed to decrypt payload: %s\n", err)
return
Expand Down
8 changes: 4 additions & 4 deletions examples/jwe_encrypt_json_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ func ExampleJWE_EncryptJSON() {
}

const payload = `Lorem ipsum`
encrypted, err := jwe.Encrypt([]byte(payload), jwe.WithJSON(), jwe.WithKey(jwa.RSA_OAEP, pubkey))
encrypted, err := jwe.Encrypt([]byte(payload), jwe.WithJSON(), jwe.WithKey(jwa.RSA_OAEP(), pubkey))
if err != nil {
fmt.Printf("failed to encrypt payload: %s\n", err)
return
}

decrypted, err := jwe.Decrypt(encrypted, jwe.WithKey(jwa.RSA_OAEP, privkey))
decrypted, err := jwe.Decrypt(encrypted, jwe.WithKey(jwa.RSA_OAEP(), privkey))
if err != nil {
fmt.Printf("failed to decrypt payload: %s\n", err)
return
Expand Down Expand Up @@ -72,7 +72,7 @@ func ExampleJWE_EncryptJSONMulti() {

options := []jwe.EncryptOption{jwe.WithJSON()}
for _, key := range pubkeys {
options = append(options, jwe.WithKey(jwa.RSA_OAEP, key))
options = append(options, jwe.WithKey(jwa.RSA_OAEP(), key))
}

const payload = `Lorem ipsum`
Expand All @@ -83,7 +83,7 @@ func ExampleJWE_EncryptJSONMulti() {
}

for _, key := range privkeys {
decrypted, err := jwe.Decrypt(encrypted, jwe.WithKey(jwa.RSA_OAEP, key))
decrypted, err := jwe.Decrypt(encrypted, jwe.WithKey(jwa.RSA_OAEP(), key))
if err != nil {
fmt.Printf("failed to decrypt payload: %s\n", err)
return
Expand Down
2 changes: 1 addition & 1 deletion examples/jwe_encrypt_with_headers_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func ExampleJWE_SignWithHeaders() {

hdrs := jwe.NewHeaders()
hdrs.Set(`x-example`, true)
encrypted, err := jwe.Encrypt([]byte(payload), jwe.WithKey(jwa.RSA_OAEP, privkey.PublicKey, jwe.WithPerRecipientHeaders(hdrs)))
encrypted, err := jwe.Encrypt([]byte(payload), jwe.WithKey(jwa.RSA_OAEP(), privkey.PublicKey, jwe.WithPerRecipientHeaders(hdrs)))
if err != nil {
fmt.Printf("failed to encrypt payload: %s\n", err)
return
Expand Down
8 changes: 4 additions & 4 deletions examples/jwe_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func exampleGenPayload() (*rsa.PrivateKey, []byte, error) {

payload := []byte("Lorem Ipsum")

encrypted, err := jwe.Encrypt(payload, jwe.WithKey(jwa.RSA1_5, &privkey.PublicKey), jwe.WithContentEncryption(jwa.A128CBC_HS256))
encrypted, err := jwe.Encrypt(payload, jwe.WithKey(jwa.RSA1_5(), &privkey.PublicKey), jwe.WithContentEncryption(jwa.A128CBC_HS256()))
if err != nil {
return nil, nil, err
}
Expand All @@ -34,7 +34,7 @@ func ExampleJWE_Decrypt() {
return
}

decrypted, err := jwe.Decrypt(encrypted, jwe.WithKey(jwa.RSA1_5, privkey))
decrypted, err := jwe.Decrypt(encrypted, jwe.WithKey(jwa.RSA1_5(), privkey))
if err != nil {
log.Printf("failed to decrypt: %s", err)
return
Expand Down Expand Up @@ -68,7 +68,7 @@ func ExampleJWE_ComplexDecrypt() {
protected.Set(`jwx-hints`, `foobar`) // in real life this would a more meaningful value
encrypted, err := jwe.Encrypt(
[]byte(payload),
jwe.WithKey(jwa.RSA_OAEP, privkey.PublicKey),
jwe.WithKey(jwa.RSA_OAEP(), privkey.PublicKey),
jwe.WithProtectedHeaders(protected),
)
if err != nil {
Expand All @@ -93,7 +93,7 @@ func ExampleJWE_ComplexDecrypt() {
// You may opt to set both the algorithm and key here as well.
// BUT BE CAREFUL so that you don't accidentally create a
// vulnerability
sink.Key(jwa.RSA_OAEP, privkey)
sink.Key(jwa.RSA_OAEP(), privkey)
return nil
}
}
Expand Down
10 changes: 5 additions & 5 deletions examples/jws_custom_signer_verifier_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func NewCirclEdDSAVerifier() (jws.Verifier, error) {
}

func (s CirclEdDSASignerVerifier) Algorithm() jwa.SignatureAlgorithm {
return jwa.EdDSA
return jwa.EdDSA()
}

func (s CirclEdDSASignerVerifier) Sign(payload []byte, keyif interface{}) ([]byte, error) {
Expand All @@ -47,8 +47,8 @@ func (s CirclEdDSASignerVerifier) Verify(payload []byte, signature []byte, keyif
func ExampleJWS_CustomSignerVerifier() {
// This example shows how to register external jws.Signer / jws.Verifier for
// a given algorithm.
jws.RegisterSigner(jwa.EdDSA, jws.SignerFactoryFn(NewCirclEdDSASigner))
jws.RegisterVerifier(jwa.EdDSA, jws.VerifierFactoryFn(NewCirclEdDSAVerifier))
jws.RegisterSigner(jwa.EdDSA(), jws.SignerFactoryFn(NewCirclEdDSASigner))
jws.RegisterVerifier(jwa.EdDSA(), jws.VerifierFactoryFn(NewCirclEdDSAVerifier))

pubkey, privkey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
Expand All @@ -57,13 +57,13 @@ func ExampleJWS_CustomSignerVerifier() {
}

const payload = "Lorem Ipsum"
signed, err := jws.Sign([]byte(payload), jws.WithKey(jwa.EdDSA, privkey))
signed, err := jws.Sign([]byte(payload), jws.WithKey(jwa.EdDSA(), privkey))
if err != nil {
fmt.Printf(`failed to generate signed message: %s`, err)
return
}

verified, err := jws.Verify(signed, jws.WithKey(jwa.EdDSA, pubkey))
verified, err := jws.Verify(signed, jws.WithKey(jwa.EdDSA(), pubkey))
if err != nil {
fmt.Printf(`failed to verify signed message: %s`, err)
return
Expand Down
Loading

0 comments on commit edde47b

Please sign in to comment.