diff --git a/core/vm/privacy/ringct.go b/core/vm/privacy/ringct.go index cc5c70d6bc..9e70ddddf2 100644 --- a/core/vm/privacy/ringct.go +++ b/core/vm/privacy/ringct.go @@ -513,6 +513,11 @@ func Verify(sig *RingSignature, verifyMes bool) bool { for j := 0; j < ringsize; j++ { var l []byte for i := 0; i < numRing; i++ { + // Validate S[i][j] and C[j] + if !isValidScalar(S[i][j], curve) || !isValidScalar(C[j], curve) { + return false // Or handle the error as required + } + // calculate L[i][j] = s[i][j]*G + c[j]*Ring[i][j] px, py := curve.ScalarMult(rings[i][j].X, rings[i][j].Y, C[j].Bytes()) // px, py = c_i*P_i sx, sy := curve.ScalarBaseMult(S[i][j].Bytes()) // sx, sy = s[i]*G @@ -524,6 +529,11 @@ func Verify(sig *RingSignature, verifyMes bool) bool { // calculate R_i = s[i][j]*H_p(Ring[i][j]) + c[j]*I[j] px, py = curve.ScalarMult(image[i].X, image[i].Y, C[j].Bytes()) // px, py = c[i]*I hx, hy := HashPoint(rings[i][j]) + + // Validate S[i][j], hx, and hy + if !isValidScalar(S[i][j], curve) || !isValidScalar(hx, curve) || !isValidScalar(hy, curve) { + return false // Or handle the error as required + } //log.Info("H[i][j]", "i", i, "j", j, "x.input", common.Bytes2Hex(rings[i][j].X.Bytes()), "y.input", common.Bytes2Hex(rings[i][j].Y.Bytes())) //log.Info("H[i][j]", "i", i, "j", j, "x", common.Bytes2Hex(hx.Bytes()), "y", common.Bytes2Hex(hy.Bytes())) sx, sy = curve.ScalarMult(hx, hy, S[i][j].Bytes()) // sx, sy = s[i]*H_p(P[i]) @@ -549,6 +559,10 @@ func Verify(sig *RingSignature, verifyMes bool) bool { return bytes.Equal(sig.C.Bytes(), C[ringsize].Bytes()) } +func isValidScalar(scalar *big.Int, curve elliptic.Curve) bool { + return scalar.Sign() >= 0 && scalar.Cmp(curve.Params().N) < 0 +} + func Link(sig_a *RingSignature, sig_b *RingSignature) bool { for i := 0; i < len(sig_a.I); i++ { for j := 0; j < len(sig_b.I); j++ { diff --git a/core/vm/privacy/ringct_test.go b/core/vm/privacy/ringct_test.go index 0c6733b6db..faef37f86a 100644 --- a/core/vm/privacy/ringct_test.go +++ b/core/vm/privacy/ringct_test.go @@ -4,8 +4,10 @@ import ( "bytes" "encoding/binary" "fmt" + "math/big" "testing" + "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/stretchr/testify/assert" ) @@ -127,3 +129,46 @@ func TestPadTo32Bytes(t *testing.T) { assert.True(t, bytes.Equal(PadTo32Bytes(arr[10:40]), arr[8:40]), "Test PadTo32Bytes shorter than 32 bytes #9") assert.True(t, bytes.Equal(PadTo32Bytes(arr[10:41]), arr[9:41]), "Test PadTo32Bytes shorter than 32 bytes #10") } + +func TestCurveScalarMult(t *testing.T) { + curve := crypto.S256() + + x, y := curve.ScalarBaseMult(curve.Params().N.Bytes()) + if x == nil && y == nil { + fmt.Println("Scalar multiplication with base point returns nil when scalar is the scalar field") + } + + x2, y2 := curve.ScalarMult(new(big.Int).SetUint64(uint64(100)), new(big.Int).SetUint64(uint64(2)), curve.Params().N.Bytes()) + if x2 == nil && y2 == nil { + fmt.Println("Scalar multiplication with a point (not necessarily on curve) returns nil when scalar is the scalar field") + } +} + +func TestNilPointerDereferencePanic(t *testing.T) { + numRing := 5 + ringSize := 10 + s := 7 + rings, privkeys, m, err := GenerateMultiRingParams(numRing, ringSize, s) + + ringSig, err := Sign(m, rings, privkeys, s) + if err != nil { + fmt.Println("Failed to set up") + } + + ringSig.S[0][0] = curve.Params().N // change one sig to the scalar field + + sig, err := ringSig.Serialize() + if err != nil { + t.Error("Failed to Serialize input Ring signature") + } + + deserializedSig, err := Deserialize(sig) + if err != nil { + t.Error("Failed to Deserialize Ring signature") + } + + verified := Verify(deserializedSig, false) + if verified { + t.Error("Should failed to verify Ring signature as the signature is invalid") + } +}