fix: missing validation in scalar range check (#417)

This commit is contained in:
Banana-J 2024-02-09 19:56:27 +11:00 committed by GitHub
parent ea7d5cc9ad
commit f453ce8315
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 59 additions and 0 deletions

View file

@ -513,6 +513,11 @@ func Verify(sig *RingSignature, verifyMes bool) bool {
for j := 0; j < ringsize; j++ {
var l []byte
for i := 0; i < numRing; i++ {
// Validate S[i][j] and C[j]
if !isValidScalar(S[i][j], curve) || !isValidScalar(C[j], curve) {
return false // Or handle the error as required
}
// calculate L[i][j] = s[i][j]*G + c[j]*Ring[i][j]
px, py := curve.ScalarMult(rings[i][j].X, rings[i][j].Y, C[j].Bytes()) // px, py = c_i*P_i
sx, sy := curve.ScalarBaseMult(S[i][j].Bytes()) // sx, sy = s[i]*G
@ -524,6 +529,11 @@ func Verify(sig *RingSignature, verifyMes bool) bool {
// calculate R_i = s[i][j]*H_p(Ring[i][j]) + c[j]*I[j]
px, py = curve.ScalarMult(image[i].X, image[i].Y, C[j].Bytes()) // px, py = c[i]*I
hx, hy := HashPoint(rings[i][j])
// Validate S[i][j], hx, and hy
if !isValidScalar(S[i][j], curve) || !isValidScalar(hx, curve) || !isValidScalar(hy, curve) {
return false // Or handle the error as required
}
//log.Info("H[i][j]", "i", i, "j", j, "x.input", common.Bytes2Hex(rings[i][j].X.Bytes()), "y.input", common.Bytes2Hex(rings[i][j].Y.Bytes()))
//log.Info("H[i][j]", "i", i, "j", j, "x", common.Bytes2Hex(hx.Bytes()), "y", common.Bytes2Hex(hy.Bytes()))
sx, sy = curve.ScalarMult(hx, hy, S[i][j].Bytes()) // sx, sy = s[i]*H_p(P[i])
@ -549,6 +559,10 @@ func Verify(sig *RingSignature, verifyMes bool) bool {
return bytes.Equal(sig.C.Bytes(), C[ringsize].Bytes())
}
func isValidScalar(scalar *big.Int, curve elliptic.Curve) bool {
return scalar.Sign() >= 0 && scalar.Cmp(curve.Params().N) < 0
}
func Link(sig_a *RingSignature, sig_b *RingSignature) bool {
for i := 0; i < len(sig_a.I); i++ {
for j := 0; j < len(sig_b.I); j++ {

View file

@ -4,8 +4,10 @@ import (
"bytes"
"encoding/binary"
"fmt"
"math/big"
"testing"
"github.com/XinFinOrg/XDPoSChain/crypto"
"github.com/stretchr/testify/assert"
)
@ -127,3 +129,46 @@ func TestPadTo32Bytes(t *testing.T) {
assert.True(t, bytes.Equal(PadTo32Bytes(arr[10:40]), arr[8:40]), "Test PadTo32Bytes shorter than 32 bytes #9")
assert.True(t, bytes.Equal(PadTo32Bytes(arr[10:41]), arr[9:41]), "Test PadTo32Bytes shorter than 32 bytes #10")
}
func TestCurveScalarMult(t *testing.T) {
curve := crypto.S256()
x, y := curve.ScalarBaseMult(curve.Params().N.Bytes())
if x == nil && y == nil {
fmt.Println("Scalar multiplication with base point returns nil when scalar is the scalar field")
}
x2, y2 := curve.ScalarMult(new(big.Int).SetUint64(uint64(100)), new(big.Int).SetUint64(uint64(2)), curve.Params().N.Bytes())
if x2 == nil && y2 == nil {
fmt.Println("Scalar multiplication with a point (not necessarily on curve) returns nil when scalar is the scalar field")
}
}
func TestNilPointerDereferencePanic(t *testing.T) {
numRing := 5
ringSize := 10
s := 7
rings, privkeys, m, err := GenerateMultiRingParams(numRing, ringSize, s)
ringSig, err := Sign(m, rings, privkeys, s)
if err != nil {
fmt.Println("Failed to set up")
}
ringSig.S[0][0] = curve.Params().N // change one sig to the scalar field
sig, err := ringSig.Serialize()
if err != nil {
t.Error("Failed to Serialize input Ring signature")
}
deserializedSig, err := Deserialize(sig)
if err != nil {
t.Error("Failed to Deserialize Ring signature")
}
verified := Verify(deserializedSig, false)
if verified {
t.Error("Should failed to verify Ring signature as the signature is invalid")
}
}