From ee303c9f3df7d5f60e33ed32301f221adff77a2a Mon Sep 17 00:00:00 2001 From: Daniel Liu Date: Fri, 29 Nov 2024 16:08:31 +0800 Subject: [PATCH] crypto/ecies: improve concatKDF (#20836) --- crypto/ecies/ecies.go | 148 ++++++++++++------------------------- crypto/ecies/ecies_test.go | 37 +++++++--- crypto/ecies/params.go | 20 +++++ 3 files changed, 94 insertions(+), 111 deletions(-) diff --git a/crypto/ecies/ecies.go b/crypto/ecies/ecies.go index bb1c8d2ff4..738bb8f584 100644 --- a/crypto/ecies/ecies.go +++ b/crypto/ecies/ecies.go @@ -35,8 +35,8 @@ import ( "crypto/elliptic" "crypto/hmac" "crypto/subtle" + "encoding/binary" "errors" - "fmt" "hash" "io" "math/big" @@ -45,7 +45,6 @@ import ( var ( ErrImport = errors.New("ecies: failed to import key") ErrInvalidCurve = errors.New("ecies: invalid elliptic curve") - ErrInvalidParams = errors.New("ecies: invalid ECIES parameters") ErrInvalidPublicKey = errors.New("ecies: invalid public key") ErrSharedKeyIsPointAtInfinity = errors.New("ecies: shared key is point at infinity") ErrSharedKeyTooBig = errors.New("ecies: shared key params are too big") @@ -139,57 +138,39 @@ func (prv *PrivateKey) GenerateShared(pub *PublicKey, skLen, macLen int) (sk []b } var ( - ErrKeyDataTooLong = errors.New("ecies: can't supply requested key data") ErrSharedTooLong = errors.New("ecies: shared secret is too long") ErrInvalidMessage = errors.New("ecies: invalid message") ) -var ( - big2To32 = new(big.Int).Exp(big.NewInt(2), big.NewInt(32), nil) - big2To32M1 = new(big.Int).Sub(big2To32, big.NewInt(1)) -) - -func incCounter(ctr []byte) { - if ctr[3]++; ctr[3] != 0 { - return - } - if ctr[2]++; ctr[2] != 0 { - return - } - if ctr[1]++; ctr[1] != 0 { - return - } - if ctr[0]++; ctr[0] != 0 { - return - } -} - // NIST SP 800-56 Concatenation Key Derivation Function (see section 5.8.1). -func concatKDF(hash hash.Hash, z, s1 []byte, kdLen int) (k []byte, err error) { - if s1 == nil { - s1 = make([]byte, 0) - } - - reps := ((kdLen + 7) * 8) / (hash.BlockSize() * 8) - if big.NewInt(int64(reps)).Cmp(big2To32M1) > 0 { - fmt.Println(big2To32M1) - return nil, ErrKeyDataTooLong - } - - counter := []byte{0, 0, 0, 1} - k = make([]byte, 0) - - for i := 0; i <= reps; i++ { - hash.Write(counter) +func concatKDF(hash hash.Hash, z, s1 []byte, kdLen int) []byte { + counterBytes := make([]byte, 4) + k := make([]byte, 0, roundup(kdLen, hash.Size())) + for counter := uint32(1); len(k) < kdLen; counter++ { + binary.BigEndian.PutUint32(counterBytes, counter) + hash.Reset() + hash.Write(counterBytes) hash.Write(z) hash.Write(s1) - k = append(k, hash.Sum(nil)...) - hash.Reset() - incCounter(counter) + k = hash.Sum(k) } + return k[:kdLen] +} - k = k[:kdLen] - return +// roundup rounds size up to the next multiple of blocksize. +func roundup(size, blocksize int) int { + return size + blocksize - (size % blocksize) +} + +// deriveKeys creates the encryption and MAC keys using concatKDF. +func deriveKeys(hash hash.Hash, z, s1 []byte, keyLen int) (Ke, Km []byte) { + K := concatKDF(hash, z, s1, 2*keyLen) + Ke = K[:keyLen] + Km = K[keyLen:] + hash.Reset() + hash.Write(Km) + Km = hash.Sum(Km[:0]) + return Ke, Km } // messageTag computes the MAC of a message (called the tag) as per @@ -210,7 +191,6 @@ func generateIV(params *ECIESParams, rand io.Reader) (iv []byte, err error) { } // symEncrypt carries out CTR encryption using the block cipher specified in the -// parameters. func symEncrypt(rand io.Reader, params *ECIESParams, key, m []byte) (ct []byte, err error) { c, err := params.Cipher(key) if err != nil { @@ -250,36 +230,27 @@ func symDecrypt(params *ECIESParams, key, ct []byte) (m []byte, err error) { // ciphertext. s1 is fed into key derivation, s2 is fed into the MAC. If the // shared information parameters aren't being used, they should be nil. func Encrypt(rand io.Reader, pub *PublicKey, m, s1, s2 []byte) (ct []byte, err error) { - params := pub.Params - if params == nil { - if params = ParamsFromCurve(pub.Curve); params == nil { - err = ErrUnsupportedECIESParameters - return - } + params, err := pubkeyParams(pub) + if err != nil { + return nil, err } + R, err := GenerateKey(rand, pub.Curve, params) if err != nil { - return + return nil, err + } + + z, err := R.GenerateShared(pub, params.KeyLen, params.KeyLen) + if err != nil { + return nil, err } hash := params.Hash() - z, err := R.GenerateShared(pub, params.KeyLen, params.KeyLen) - if err != nil { - return - } - K, err := concatKDF(hash, z, s1, params.KeyLen+params.KeyLen) - if err != nil { - return - } - Ke := K[:params.KeyLen] - Km := K[params.KeyLen:] - hash.Write(Km) - Km = hash.Sum(nil) - hash.Reset() + Ke, Km := deriveKeys(hash, z, s1, params.KeyLen) em, err := symEncrypt(rand, params, Ke, m) if err != nil || len(em) <= params.BlockSize { - return + return nil, err } d := messageTag(params.Hash, Km, em, s2) @@ -289,7 +260,7 @@ func Encrypt(rand io.Reader, pub *PublicKey, m, s1, s2 []byte) (ct []byte, err e copy(ct, Rb) copy(ct[len(Rb):], em) copy(ct[len(Rb)+len(em):], d) - return + return ct, nil } // Decrypt decrypts an ECIES ciphertext. @@ -297,13 +268,11 @@ func (prv *PrivateKey) Decrypt(c, s1, s2 []byte) (m []byte, err error) { if len(c) == 0 { return nil, ErrInvalidMessage } - params := prv.PublicKey.Params - if params == nil { - if params = ParamsFromCurve(prv.PublicKey.Curve); params == nil { - err = ErrUnsupportedECIESParameters - return - } + params, err := pubkeyParams(&prv.PublicKey) + if err != nil { + return nil, err } + hash := params.Hash() var ( @@ -317,12 +286,10 @@ func (prv *PrivateKey) Decrypt(c, s1, s2 []byte) (m []byte, err error) { case 2, 3, 4: rLen = (prv.PublicKey.Curve.Params().BitSize + 7) / 4 if len(c) < (rLen + hLen + 1) { - err = ErrInvalidMessage - return + return nil, ErrInvalidMessage } default: - err = ErrInvalidPublicKey - return + return nil, ErrInvalidPublicKey } mStart = rLen @@ -332,36 +299,19 @@ func (prv *PrivateKey) Decrypt(c, s1, s2 []byte) (m []byte, err error) { R.Curve = prv.PublicKey.Curve R.X, R.Y = elliptic.Unmarshal(R.Curve, c[:rLen]) if R.X == nil { - err = ErrInvalidPublicKey - return - } - if !R.Curve.IsOnCurve(R.X, R.Y) { - err = ErrInvalidCurve - return + return nil, ErrInvalidPublicKey } z, err := prv.GenerateShared(R, params.KeyLen, params.KeyLen) if err != nil { - return + return nil, err } - - K, err := concatKDF(hash, z, s1, params.KeyLen+params.KeyLen) - if err != nil { - return - } - - Ke := K[:params.KeyLen] - Km := K[params.KeyLen:] - hash.Write(Km) - Km = hash.Sum(nil) - hash.Reset() + Ke, Km := deriveKeys(hash, z, s1, params.KeyLen) d := messageTag(params.Hash, Km, c[mStart:mEnd], s2) if subtle.ConstantTimeCompare(c[mEnd:], d) != 1 { - err = ErrInvalidMessage - return + return nil, ErrInvalidMessage } - m, err = symDecrypt(params, Ke, c[mStart:mEnd]) - return + return symDecrypt(params, Ke, c[mStart:mEnd]) } diff --git a/crypto/ecies/ecies_test.go b/crypto/ecies/ecies_test.go index 235c2b59a8..a1a7604ba8 100644 --- a/crypto/ecies/ecies_test.go +++ b/crypto/ecies/ecies_test.go @@ -42,17 +42,23 @@ import ( "github.com/XinFinOrg/XDPoSChain/crypto" ) -// Ensure the KDF generates appropriately sized keys. func TestKDF(t *testing.T) { - msg := []byte("Hello, world") - h := sha256.New() - - k, err := concatKDF(h, msg, nil, 64) - if err != nil { - t.Fatal(err) + tests := []struct { + length int + output []byte + }{ + {6, decode("858b192fa2ed")}, + {32, decode("858b192fa2ed4395e2bf88dd8d5770d67dc284ee539f12da8bceaa45d06ebae0")}, + {48, decode("858b192fa2ed4395e2bf88dd8d5770d67dc284ee539f12da8bceaa45d06ebae0700f1ab918a5f0413b8140f9940d6955")}, + {64, decode("858b192fa2ed4395e2bf88dd8d5770d67dc284ee539f12da8bceaa45d06ebae0700f1ab918a5f0413b8140f9940d6955f3467fd6672cce1024c5b1effccc0f61")}, } - if len(k) != 64 { - t.Fatalf("KDF: generated key is the wrong size (%d instead of 64\n", len(k)) + + for _, test := range tests { + h := sha256.New() + k := concatKDF(h, []byte("input"), nil, test.length) + if !bytes.Equal(k, test.output) { + t.Fatalf("KDF: generated key %x does not match expected output %x", k, test.output) + } } } @@ -293,8 +299,8 @@ func TestParamSelection(t *testing.T) { func testParamSelection(t *testing.T, c testCase) { params := ParamsFromCurve(c.Curve) - if params == nil && c.Expected != nil { - t.Fatalf("%s (%s)\n", ErrInvalidParams.Error(), c.Name) + if params == nil { + t.Fatal("ParamsFromCurve returned nil") } else if params != nil && !cmpParams(params, c.Expected) { t.Fatalf("ecies: parameters should be invalid (%s)\n", c.Name) } @@ -328,7 +334,6 @@ func testParamSelection(t *testing.T, c testCase) { if err == nil { t.Fatalf("ecies: encryption should not have succeeded (%s)\n", c.Name) } - } // Ensure that the basic public key validation in the decryption operation @@ -414,3 +419,11 @@ func hexKey(prv string) *PrivateKey { } return ImportECDSA(key) } + +func decode(s string) []byte { + bytes, err := hex.DecodeString(s) + if err != nil { + panic(err) + } + return bytes +} diff --git a/crypto/ecies/params.go b/crypto/ecies/params.go index bd5969f1a9..293f2e83b1 100644 --- a/crypto/ecies/params.go +++ b/crypto/ecies/params.go @@ -40,6 +40,7 @@ import ( "crypto/sha256" "crypto/sha512" "errors" + "fmt" "hash" ethcrypto "github.com/XinFinOrg/XDPoSChain/crypto" @@ -49,8 +50,14 @@ var ( DefaultCurve = ethcrypto.S256() ErrUnsupportedECDHAlgorithm = errors.New("ecies: unsupported ECDH algorithm") ErrUnsupportedECIESParameters = errors.New("ecies: unsupported ECIES parameters") + ErrInvalidKeyLen = fmt.Errorf("ecies: invalid key size (> %d) in ECIESParams", maxKeyLen) ) +// KeyLen is limited to prevent overflow of the counter +// in concatKDF. While the theoretical limit is much higher, +// no known cipher uses keys larger than 512 bytes. +const maxKeyLen = 512 + type ECIESParams struct { Hash func() hash.Hash // hash function hashAlgo crypto.Hash @@ -115,3 +122,16 @@ func AddParamsForCurve(curve elliptic.Curve, params *ECIESParams) { func ParamsFromCurve(curve elliptic.Curve) (params *ECIESParams) { return paramsFromCurve[curve] } + +func pubkeyParams(key *PublicKey) (*ECIESParams, error) { + params := key.Params + if params == nil { + if params = ParamsFromCurve(key.Curve); params == nil { + return nil, ErrUnsupportedECIESParameters + } + } + if params.KeyLen > maxKeyLen { + return nil, ErrInvalidKeyLen + } + return params, nil +}