diff --git a/trie/notary.go b/trie/notary.go deleted file mode 100644 index 51ec1658d5..0000000000 --- a/trie/notary.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2020 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package trie - -import ( - "github.com/XinFinOrg/XDPoSChain/ethdb" - "github.com/XinFinOrg/XDPoSChain/ethdb/memorydb" -) - -// keyValueNotary tracks which keys have been accessed through a key-value reader -// with te scope of verifying if certain proof datasets are maliciously bloated. -type keyValueNotary struct { - ethdb.KeyValueReader - reads map[string]struct{} -} - -// newKeyValueNotary wraps a key-value database with an access notary to track -// which items have bene accessed. -func newKeyValueNotary(db ethdb.KeyValueReader) *keyValueNotary { - return &keyValueNotary{ - KeyValueReader: db, - reads: make(map[string]struct{}), - } -} - -// Get retrieves an item from the underlying database, but also tracks it as an -// accessed slot for bloat checks. -func (k *keyValueNotary) Get(key []byte) ([]byte, error) { - k.reads[string(key)] = struct{}{} - return k.KeyValueReader.Get(key) -} - -// Accessed returns s snapshot of the original key-value store containing only the -// data accessed through the notary. -func (k *keyValueNotary) Accessed() ethdb.KeyValueStore { - db := memorydb.New() - for keystr := range k.reads { - key := []byte(keystr) - val, _ := k.KeyValueReader.Get(key) - db.Put(key, val) - } - return db -} diff --git a/trie/proof.go b/trie/proof.go index 8a8132c266..5c9bbeadb8 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -463,108 +463,91 @@ func hasRightElement(node node, key []byte) bool { // // Except returning the error to indicate the proof is valid or not, the function will // also return a flag to indicate whether there exists more accounts/slots in the trie. -func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (ethdb.KeyValueStore, bool, error) { +// +// Note: This method does not verify that the proof is of minimal form. If the input +// proofs are 'bloated' with neighbour leaves or random data, aside from the 'useful' +// data, then the proof will still be accepted. +func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (bool, error) { if len(keys) != len(values) { - return nil, false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values)) + return false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values)) } // Ensure the received batch is monotonic increasing. for i := 0; i < len(keys)-1; i++ { if bytes.Compare(keys[i], keys[i+1]) >= 0 { - return nil, false, errors.New("range is not monotonically increasing") + return false, errors.New("range is not monotonically increasing") } } - // Create a key-value notary to track which items from the given proof the - // range prover actually needed to verify the data - notary := newKeyValueNotary(proof) - // Special case, there is no edge proof at all. The given range is expected // to be the whole leaf-set in the trie. if proof == nil { - var ( - diskdb = memorydb.New() - tr = NewStackTrie(diskdb) - ) + tr := NewStackTrie(nil) for index, key := range keys { tr.TryUpdate(key, values[index]) } if have, want := tr.Hash(), rootHash; have != want { - return nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have) + return false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have) } - // Proof seems valid, serialize remaining nodes into the database - if _, err := tr.Commit(); err != nil { - return nil, false, err - } - return diskdb, false, nil // No more elements + return false, nil // No more elements } // Special case, there is a provided edge proof but zero key/value // pairs, ensure there are no more accounts / slots in the trie. if len(keys) == 0 { - root, val, err := proofToPath(rootHash, nil, firstKey, notary, true) + root, val, err := proofToPath(rootHash, nil, firstKey, proof, true) if err != nil { - return nil, false, err + return false, err } if val != nil || hasRightElement(root, firstKey) { - return nil, false, errors.New("more entries available") + return false, errors.New("more entries available") } - // Since the entire proof is a single path, we can construct a trie and a - // node database directly out of the inputs, no need to generate them - diskdb := notary.Accessed() - return diskdb, hasRightElement(root, firstKey), nil + return hasRightElement(root, firstKey), nil } // Special case, there is only one element and two edge keys are same. // In this case, we can't construct two edge paths. So handle it here. if len(keys) == 1 && bytes.Equal(firstKey, lastKey) { - root, val, err := proofToPath(rootHash, nil, firstKey, notary, false) + root, val, err := proofToPath(rootHash, nil, firstKey, proof, false) if err != nil { - return nil, false, err + return false, err } if !bytes.Equal(firstKey, keys[0]) { - return nil, false, errors.New("correct proof but invalid key") + return false, errors.New("correct proof but invalid key") } if !bytes.Equal(val, values[0]) { - return nil, false, errors.New("correct proof but invalid data") + return false, errors.New("correct proof but invalid data") } - // Since the entire proof is a single path, we can construct a trie and a - // node database directly out of the inputs, no need to generate them - diskdb := notary.Accessed() - return diskdb, hasRightElement(root, firstKey), nil + return hasRightElement(root, firstKey), nil } // Ok, in all other cases, we require two edge paths available. // First check the validity of edge keys. if bytes.Compare(firstKey, lastKey) >= 0 { - return nil, false, errors.New("invalid edge keys") + return false, errors.New("invalid edge keys") } // todo(rjl493456442) different length edge keys should be supported if len(firstKey) != len(lastKey) { - return nil, false, errors.New("inconsistent edge keys") + return false, errors.New("inconsistent edge keys") } // Convert the edge proofs to edge trie paths. Then we can // have the same tree architecture with the original one. // For the first edge proof, non-existent proof is allowed. - root, _, err := proofToPath(rootHash, nil, firstKey, notary, true) + root, _, err := proofToPath(rootHash, nil, firstKey, proof, true) if err != nil { - return nil, false, err + return false, err } // Pass the root node here, the second path will be merged // with the first one. For the last edge proof, non-existent // proof is also allowed. - root, _, err = proofToPath(rootHash, root, lastKey, notary, true) + root, _, err = proofToPath(rootHash, root, lastKey, proof, true) if err != nil { - return nil, false, err + return false, err } // Remove all internal references. All the removed parts should // be re-filled(or re-constructed) by the given leaves range. empty, err := unsetInternal(root, firstKey, lastKey) if err != nil { - return nil, false, err + return false, err } // Rebuild the trie with the leaf stream, the shape of trie // should be same with the original one. - var ( - diskdb = memorydb.New() - triedb = NewDatabase(diskdb) - ) - tr := &Trie{root: root, Db: triedb} + tr := &Trie{root: root, Db: NewDatabase(memorydb.New())} if empty { tr.root = nil } @@ -572,16 +555,9 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key tr.TryUpdate(key, values[index]) } if tr.Hash() != rootHash { - return nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash()) + return false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash()) } - // Proof seems valid, serialize all the nodes into the database - if _, _, err := tr.Commit(nil); err != nil { - return nil, false, err - } - if err := triedb.Commit(rootHash, false); err != nil { - return nil, false, err - } - return diskdb, hasRightElement(root, keys[len(keys)-1]), nil + return hasRightElement(root, keys[len(keys)-1]), nil } // get returns the child of the given Node. Return nil if the diff --git a/trie/proof_test.go b/trie/proof_test.go index d31e94ebcc..6c3c3977bb 100644 --- a/trie/proof_test.go +++ b/trie/proof_test.go @@ -177,7 +177,7 @@ func TestRangeProof(t *testing.T) { keys = append(keys, entries[i].k) vals = append(vals, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) + _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) if err != nil { t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) } @@ -228,7 +228,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) { keys = append(keys, entries[i].k) vals = append(vals, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) + _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) if err != nil { t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) } @@ -249,7 +249,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof) + _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof) if err != nil { t.Fatal("Failed to verify whole rang with non-existent edges") } @@ -284,7 +284,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof) + _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof) if err == nil { t.Fatalf("Expected to detect the error, got nil") } @@ -306,7 +306,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof) + _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof) if err == nil { t.Fatalf("Expected to detect the error, got nil") } @@ -330,7 +330,7 @@ func TestOneElementRangeProof(t *testing.T) { if err := trie.Prove(entries[start].k, 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } - _, _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -345,7 +345,7 @@ func TestOneElementRangeProof(t *testing.T) { if err := trie.Prove(entries[start].k, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -360,7 +360,7 @@ func TestOneElementRangeProof(t *testing.T) { if err := trie.Prove(last, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -375,7 +375,7 @@ func TestOneElementRangeProof(t *testing.T) { if err := trie.Prove(last, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) + _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -394,7 +394,7 @@ func TestOneElementRangeProof(t *testing.T) { if err := tinyTrie.Prove(last, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof) + _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -416,7 +416,7 @@ func TestAllElementsProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil) + _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -429,7 +429,7 @@ func TestAllElementsProof(t *testing.T) { if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof) + _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -444,7 +444,7 @@ func TestAllElementsProof(t *testing.T) { if err := trie.Prove(last, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof) + _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -477,7 +477,7 @@ func TestSingleSideRangeProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) + _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -513,7 +513,7 @@ func TestReverseSingleSideRangeProof(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof) + _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -585,7 +585,7 @@ func TestBadRangeProof(t *testing.T) { index = mrand.Intn(end - start) vals[index] = nil } - _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) + _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) if err == nil { t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1) } @@ -619,7 +619,7 @@ func TestGappedRangeProof(t *testing.T) { keys = append(keys, entries[i].k) vals = append(vals, entries[i].v) } - _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) + _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) if err == nil { t.Fatal("expect error, got nil") } @@ -646,7 +646,7 @@ func TestSameSideProofs(t *testing.T) { if err := trie.Prove(last, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) + _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) if err == nil { t.Fatalf("Expected error, got nil") } @@ -662,7 +662,7 @@ func TestSameSideProofs(t *testing.T) { if err := trie.Prove(last, 0, proof); err != nil { t.Fatalf("Failed to prove the last node %v", err) } - _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) + _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) if err == nil { t.Fatalf("Expected error, got nil") } @@ -730,7 +730,7 @@ func TestHasRightElement(t *testing.T) { k = append(k, entries[i].k) v = append(v, entries[i].v) } - _, hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof) + hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof) if err != nil { t.Fatalf("Expected no error, got %v", err) } @@ -763,25 +763,19 @@ func TestEmptyRangeProof(t *testing.T) { if err := trie.Prove(first, 0, proof); err != nil { t.Fatalf("Failed to prove the first node %v", err) } - db, _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof) + _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof) if c.err && err == nil { t.Fatalf("Expected error, got nil") } if !c.err && err != nil { t.Fatalf("Expected no error, got %v", err) } - // If no error was returned, ensure the returned database contains - // the entire proof, since there's no value - if !c.err { - if memdb := db.(*memorydb.Database); memdb.Len() != proof.Len() { - t.Errorf("database entry count mismatch: have %d, want %d", memdb.Len(), proof.Len()) - } - } } } // TestBloatedProof tests a malicious proof, where the proof is more or less the -// whole trie. +// whole trie. Previously we didn't accept such packets, but the new APIs do, so +// lets leave this test as a bit weird, but present. func TestBloatedProof(t *testing.T) { // Use a small trie trie, kvs := nonRandomTrie(100) @@ -809,10 +803,8 @@ func TestBloatedProof(t *testing.T) { trie.Prove(keys[0], 0, want) trie.Prove(keys[len(keys)-1], 0, want) - db, _, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) - // The db should not contain anything of the bloated data - if used := db.(*memorydb.Database); used.Len() != want.Len() { - t.Fatalf("notary proof size mismatch: have %d, want %d", used.Len(), want.Len()) + if _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof); err != nil { + t.Fatalf("expected bloated proof to succeed, got %v", err) } } @@ -916,7 +908,7 @@ func benchmarkVerifyRangeProof(b *testing.B, size int) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof) + _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof) if err != nil { b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) } @@ -943,7 +935,7 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) { } b.ResetTimer() for i := 0; i < b.N; i++ { - _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, nil) + _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, nil) if err != nil { b.Fatalf("Expected no error, got %v", err) }