Skip to content

Commit

Permalink
deduplicate code
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Dec 20, 2024
1 parent be8db59 commit 74f51cc
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 113 deletions.
142 changes: 82 additions & 60 deletions cng/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ import (
"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
)

// maxHashSize is the size of SHA52 and SHA3_512, the largest hashes we support.
const maxHashSize = 64

// SupportsHash returns true if a hash.Hash implementation is supported for h.
func SupportsHash(h crypto.Hash) bool {
switch h {
Expand Down Expand Up @@ -145,11 +148,11 @@ func hashToID(h hash.Hash) string {
return hx.alg.id
}

// hashX implements [hash.Hash].
type hashX struct {
alg *hashAlgorithm
_ctx bcrypt.HASH_HANDLE // access it using withCtx
alg *hashAlgorithm
ctx bcrypt.HASH_HANDLE

buf []byte
key []byte
}

Expand All @@ -160,37 +163,34 @@ func newHashX(id string, flag bcrypt.AlgorithmProviderFlags, key []byte) *hashX
panic(err)
}
h := &hashX{alg: alg, key: bytes.Clone(key)}
// Don't allocate hx.buf nor call bcrypt.CreateHash yet,
// which would be wasteful if the caller only wants to know
// the hash type. This is a common pattern in this package,
// as some functions accept a `func() hash.Hash` parameter
// and call it just to know the hash type.
runtime.SetFinalizer(h, (*hashX).finalize)
// Don't call bcrypt.CreateHash yet, it would be wasteful
// if the caller only wants to know the hash type. This
// is a common pattern in this package, as some functions
// accept a `func() hash.Hash` parameter and call it just
// to know the hash type.
return h
}

func (h *hashX) finalize() {
if h._ctx != 0 {
bcrypt.DestroyHash(h._ctx)
}
bcrypt.DestroyHash(h.ctx)
}

func (h *hashX) withCtx(fn func(ctx bcrypt.HASH_HANDLE) error) error {
func (h *hashX) init() {
defer runtime.KeepAlive(h)
if h._ctx == 0 {
err := bcrypt.CreateHash(h.alg.handle, &h._ctx, nil, h.key, 0)
if err != nil {
panic(err)
}
if h.ctx != 0 {
return
}
err := bcrypt.CreateHash(h.alg.handle, &h.ctx, nil, h.key, bcrypt.HASH_REUSABLE_FLAG)
if err != nil {
panic(err)
}
return fn(h._ctx)
runtime.SetFinalizer(h, (*hashX).finalize)
}

func (h *hashX) Clone() (hash.Hash, error) {
defer runtime.KeepAlive(h)
h2 := &hashX{alg: h.alg, key: bytes.Clone(h.key)}
err := h.withCtx(func(ctx bcrypt.HASH_HANDLE) error {
return bcrypt.DuplicateHash(ctx, &h2._ctx, nil, 0)
})
err := bcrypt.DuplicateHash(h.ctx, &h2.ctx, nil, 0)
if err != nil {
return nil, err
}
Expand All @@ -199,49 +199,37 @@ func (h *hashX) Clone() (hash.Hash, error) {
}

func (h *hashX) Reset() {
if h._ctx != 0 {
bcrypt.DestroyHash(h._ctx)
h._ctx = 0
defer runtime.KeepAlive(h)
if h.ctx != 0 {
hashReset(h.ctx, h.Size())
}
}

func (h *hashX) Write(p []byte) (n int, err error) {
err = h.withCtx(func(ctx bcrypt.HASH_HANDLE) error {
for n < len(p) && err == nil {
nn := len32(p[n:])
err = bcrypt.HashData(h._ctx, p[n:n+nn], 0)
n += nn
}
return err
})
if err != nil {
// hash.Hash interface mandates Write should never return an error.
panic(err)
}
defer runtime.KeepAlive(h)
h.init()
hashData(h.ctx, p)
return len(p), nil
}

func (h *hashX) WriteString(s string) (int, error) {
// TODO: use unsafe.StringData once we drop support
// for go1.19 and earlier.
hdr := (*struct {
Data *byte
Len int
})(unsafe.Pointer(&s))
return h.Write(unsafe.Slice(hdr.Data, len(s)))
defer runtime.KeepAlive(h)
return h.Write(unsafe.Slice(unsafe.StringData(s), len(s)))
}

func (h *hashX) WriteByte(c byte) error {
err := h.withCtx(func(ctx bcrypt.HASH_HANDLE) error {
return bcrypt.HashDataRaw(h._ctx, &c, 1, 0)
})
if err != nil {
// hash.Hash interface mandates Write should never return an error.
panic(err)
}
defer runtime.KeepAlive(h)
h.init()
hashByte(h.ctx, c)
return nil
}

func (h *hashX) Sum(in []byte) []byte {
defer runtime.KeepAlive(h)
h.init()
return hashSum(h.ctx, h.Size(), in)
}

func (h *hashX) Size() int {
return int(h.alg.size)
}
Expand All @@ -250,21 +238,55 @@ func (h *hashX) BlockSize() int {
return int(h.alg.blockSize)
}

func (h *hashX) Sum(in []byte) []byte {
// hashData writes p to ctx. It panics on error.
func hashData(ctx bcrypt.HASH_HANDLE, p []byte) {
var n int
var err error
for n < len(p) && err == nil {
nn := len32(p[n:])
err = bcrypt.HashData(ctx, p[n:n+nn], 0)
n += nn
}
if err != nil {
panic(err)
}
}

// hashByte writes c to ctx. It panics on error.
func hashByte(ctx bcrypt.HASH_HANDLE, c byte) {
err := bcrypt.HashDataRaw(ctx, &c, 1, 0)
if err != nil {
panic(err)
}
}

// hashSum writes the hash of ctx to in and returns the result.
// size is the size of the hash output.
// It panics on error.
func hashSum(ctx bcrypt.HASH_HANDLE, size int, in []byte) []byte {
var ctx2 bcrypt.HASH_HANDLE
err := h.withCtx(func(ctx bcrypt.HASH_HANDLE) error {
return bcrypt.DuplicateHash(ctx, &ctx2, nil, 0)
})
err := bcrypt.DuplicateHash(ctx, &ctx2, nil, 0)
if err != nil {
panic(err)
}
defer bcrypt.DestroyHash(ctx2)
if h.buf == nil {
h.buf = make([]byte, h.alg.size)
}
err = bcrypt.FinishHash(ctx2, h.buf, 0)
buf := make([]byte, size, maxHashSize) // explicit cap to allow stack allocation
err = bcrypt.FinishHash(ctx2, buf, 0)
if err != nil {
panic(err)
}
return append(in, h.buf...)
return append(in, buf...)
}

// hashReset resets the hash state of ctx.
// size is the size of the hash output.
// It panics on error.
func hashReset(ctx bcrypt.HASH_HANDLE, size int) {
// bcrypt.FinishHash expects the output buffer to match the hash size.
// We don't care about the output, so we just pass a stack-allocated buffer
// that is large enough to hold the largest hash size we support.
var discard [maxHashSize]byte
if err := bcrypt.FinishHash(ctx, discard[:size], 0); err != nil {
panic(err)
}
}
66 changes: 13 additions & 53 deletions cng/sha3.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ import (
"github.com/microsoft/go-crypto-winnative/internal/bcrypt"
)

// maxSHA3Size is the size of SHA3_512, the largest SHA3 hash we support.
const maxSHA3Size = 64

// SumSHA3_256 returns the SHA3-256 checksum of the data.
func SumSHA3_256(p []byte) (sum [32]byte) {
if err := hashOneShot(bcrypt.SHA3_256_ALGORITHM, p, sum[:]); err != nil {
Expand Down Expand Up @@ -123,28 +120,14 @@ func (h *DigestSHA3) Clone() (hash.Hash, error) {
func (h *DigestSHA3) Reset() {
defer runtime.KeepAlive(h)
if h.ctx != 0 {
// bcrypt.FinishHash expects the output buffer to match the hash size.
// We don't care about the output, so we just pass a stack-allocated buffer
// that is large enough to hold the largest hash size we support.
var discard [maxSHA3Size]byte
if err := bcrypt.FinishHash(h.ctx, discard[:h.Size()], 0); err != nil {
panic(err)
}
hashReset(h.ctx, h.Size())
}
}

func (h *DigestSHA3) Write(p []byte) (n int, err error) {
defer runtime.KeepAlive(h)
h.init()
for n < len(p) && err == nil {
nn := len32(p[n:])
err = bcrypt.HashData(h.ctx, p[n:n+nn], 0)
n += nn
}
if err != nil {
// hash.Hash interface mandates Write should never return an error.
panic(err)
}
hashData(h.ctx, p)
return len(p), nil
}

Expand All @@ -156,14 +139,16 @@ func (h *DigestSHA3) WriteString(s string) (int, error) {
func (h *DigestSHA3) WriteByte(c byte) error {
defer runtime.KeepAlive(h)
h.init()
err := bcrypt.HashDataRaw(h.ctx, &c, 1, 0)
if err != nil {
// hash.Hash interface mandates Write should never return an error.
panic(err)
}
hashByte(h.ctx, c)
return nil
}

func (h *DigestSHA3) Sum(in []byte) []byte {
defer runtime.KeepAlive(h)
h.init()
return hashSum(h.ctx, h.Size(), in)
}

func (h *DigestSHA3) Size() int {
return int(h.alg.size)
}
Expand All @@ -172,23 +157,6 @@ func (h *DigestSHA3) BlockSize() int {
return int(h.alg.blockSize)
}

func (h *DigestSHA3) Sum(in []byte) []byte {
defer runtime.KeepAlive(h)
h.init()
var ctx2 bcrypt.HASH_HANDLE
err := bcrypt.DuplicateHash(h.ctx, &ctx2, nil, 0)
if err != nil {
panic(err)
}
defer bcrypt.DestroyHash(ctx2)
buf := make([]byte, h.alg.size, maxSHA3Size) // explicit cap to allow stack allocation
err = bcrypt.FinishHash(ctx2, buf, 0)
if err != nil {
panic(err)
}
return append(in, buf...)
}

// NewSHA3_256 returns a new SHA256 hash.
func NewSHA3_256() *DigestSHA3 {
return newDigestSHA3(bcrypt.SHA3_256_ALGORITHM)
Expand Down Expand Up @@ -281,14 +249,7 @@ func (s *SHAKE) Write(p []byte) (n int, err error) {
return 0, nil
}
defer runtime.KeepAlive(s)
for n < len(p) && err == nil {
nn := len32(p[n:])
err = bcrypt.HashData(s.ctx, p[n:n+nn], 0)
n += nn
}
if err != nil {
panic(err)
}
hashData(s.ctx, p)
return len(p), nil
}

Expand All @@ -314,10 +275,9 @@ func (s *SHAKE) Read(p []byte) (n int, err error) {
// Reset resets the XOF to its initial state.
func (s *SHAKE) Reset() {
defer runtime.KeepAlive(s)
var discard [1]byte
if err := bcrypt.FinishHash(s.ctx, discard[:], 0); err != nil {
panic(err)
}
// SHAKE has a variable size, CNG doesn't change the size of the hash
// when resetting, so we can pass a small value here.
hashReset(s.ctx, 1)
}

// BlockSize returns the rate of the XOF.
Expand Down

0 comments on commit 74f51cc

Please sign in to comment.