trie: remove redundant returns + use stacktrie where applicable #22760 (#1066)

* trie: add benchmark for proofless range

* trie: remove unused returns + use stacktrie

Co-authored-by: Martin Holst Swende <martin@swende.se>
This commit is contained in:
Daniel Liu 2025-07-28 16:49:31 +08:00 committed by GitHub
parent db9c3de1dc
commit f552cebfcd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 91 additions and 80 deletions

View file

@ -21,17 +21,17 @@ import (
"github.com/XinFinOrg/XDPoSChain/ethdb/memorydb" "github.com/XinFinOrg/XDPoSChain/ethdb/memorydb"
) )
// KeyValueNotary tracks which keys have been accessed through a key-value reader // keyValueNotary tracks which keys have been accessed through a key-value reader
// with te scope of verifying if certain proof datasets are maliciously bloated. // with te scope of verifying if certain proof datasets are maliciously bloated.
type KeyValueNotary struct { type keyValueNotary struct {
ethdb.KeyValueReader ethdb.KeyValueReader
reads map[string]struct{} reads map[string]struct{}
} }
// NewKeyValueNotary wraps a key-value database with an access notary to track // newKeyValueNotary wraps a key-value database with an access notary to track
// which items have bene accessed. // which items have bene accessed.
func NewKeyValueNotary(db ethdb.KeyValueReader) *KeyValueNotary { func newKeyValueNotary(db ethdb.KeyValueReader) *keyValueNotary {
return &KeyValueNotary{ return &keyValueNotary{
KeyValueReader: db, KeyValueReader: db,
reads: make(map[string]struct{}), reads: make(map[string]struct{}),
} }
@ -39,14 +39,14 @@ func NewKeyValueNotary(db ethdb.KeyValueReader) *KeyValueNotary {
// Get retrieves an item from the underlying database, but also tracks it as an // Get retrieves an item from the underlying database, but also tracks it as an
// accessed slot for bloat checks. // accessed slot for bloat checks.
func (k *KeyValueNotary) Get(key []byte) ([]byte, error) { func (k *keyValueNotary) Get(key []byte) ([]byte, error) {
k.reads[string(key)] = struct{}{} k.reads[string(key)] = struct{}{}
return k.KeyValueReader.Get(key) return k.KeyValueReader.Get(key)
} }
// Accessed returns s snapshot of the original key-value store containing only the // Accessed returns s snapshot of the original key-value store containing only the
// data accessed through the notary. // data accessed through the notary.
func (k *KeyValueNotary) Accessed() ethdb.KeyValueStore { func (k *keyValueNotary) Accessed() ethdb.KeyValueStore {
db := memorydb.New() db := memorydb.New()
for keystr := range k.reads { for keystr := range k.reads {
key := []byte(keystr) key := []byte(keystr)

View file

@ -463,115 +463,100 @@ func hasRightElement(node node, key []byte) bool {
// //
// Except returning the error to indicate the proof is valid or not, the function will // Except returning the error to indicate the proof is valid or not, the function will
// also return a flag to indicate whether there exists more accounts/slots in the trie. // also return a flag to indicate whether there exists more accounts/slots in the trie.
func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (ethdb.KeyValueStore, *Trie, *KeyValueNotary, bool, error) { func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (ethdb.KeyValueStore, bool, error) {
if len(keys) != len(values) { if len(keys) != len(values) {
return nil, nil, nil, false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values)) return nil, false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values))
} }
// Ensure the received batch is monotonic increasing. // Ensure the received batch is monotonic increasing.
for i := 0; i < len(keys)-1; i++ { for i := 0; i < len(keys)-1; i++ {
if bytes.Compare(keys[i], keys[i+1]) >= 0 { if bytes.Compare(keys[i], keys[i+1]) >= 0 {
return nil, nil, nil, false, errors.New("range is not monotonically increasing") return nil, false, errors.New("range is not monotonically increasing")
} }
} }
// Create a key-value notary to track which items from the given proof the // Create a key-value notary to track which items from the given proof the
// range prover actually needed to verify the data // range prover actually needed to verify the data
notary := NewKeyValueNotary(proof) notary := newKeyValueNotary(proof)
// Special case, there is no edge proof at all. The given range is expected // Special case, there is no edge proof at all. The given range is expected
// to be the whole leaf-set in the trie. // to be the whole leaf-set in the trie.
if proof == nil { if proof == nil {
var ( var (
diskdb = memorydb.New() diskdb = memorydb.New()
triedb = NewDatabase(diskdb) tr = NewStackTrie(diskdb)
) )
tr, err := New(common.Hash{}, triedb)
if err != nil {
return nil, nil, nil, false, err
}
for index, key := range keys { for index, key := range keys {
tr.TryUpdate(key, values[index]) tr.TryUpdate(key, values[index])
} }
if tr.Hash() != rootHash { if have, want := tr.Hash(), rootHash; have != want {
return nil, nil, nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash()) return nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have)
} }
// Proof seems valid, serialize all the nodes into the database // Proof seems valid, serialize remaining nodes into the database
if _, err := tr.Commit(nil); err != nil { if _, err := tr.Commit(); err != nil {
return nil, nil, nil, false, err return nil, false, err
} }
if err := triedb.Commit(rootHash, false); err != nil { return diskdb, false, nil // No more elements
return nil, nil, nil, false, err
}
return diskdb, tr, notary, false, nil // No more elements
} }
// Special case, there is a provided edge proof but zero key/value // Special case, there is a provided edge proof but zero key/value
// pairs, ensure there are no more accounts / slots in the trie. // pairs, ensure there are no more accounts / slots in the trie.
if len(keys) == 0 { if len(keys) == 0 {
root, val, err := proofToPath(rootHash, nil, firstKey, notary, true) root, val, err := proofToPath(rootHash, nil, firstKey, notary, true)
if err != nil { if err != nil {
return nil, nil, nil, false, err return nil, false, err
} }
if val != nil || hasRightElement(root, firstKey) { if val != nil || hasRightElement(root, firstKey) {
return nil, nil, nil, false, errors.New("more entries available") return nil, false, errors.New("more entries available")
} }
// Since the entire proof is a single path, we can construct a trie and a // Since the entire proof is a single path, we can construct a trie and a
// node database directly out of the inputs, no need to generate them // node database directly out of the inputs, no need to generate them
diskdb := notary.Accessed() diskdb := notary.Accessed()
tr := &Trie{ return diskdb, hasRightElement(root, firstKey), nil
Db: NewDatabase(diskdb),
root: root,
}
return diskdb, tr, notary, hasRightElement(root, firstKey), nil
} }
// Special case, there is only one element and two edge keys are same. // Special case, there is only one element and two edge keys are same.
// In this case, we can't construct two edge paths. So handle it here. // In this case, we can't construct two edge paths. So handle it here.
if len(keys) == 1 && bytes.Equal(firstKey, lastKey) { if len(keys) == 1 && bytes.Equal(firstKey, lastKey) {
root, val, err := proofToPath(rootHash, nil, firstKey, notary, false) root, val, err := proofToPath(rootHash, nil, firstKey, notary, false)
if err != nil { if err != nil {
return nil, nil, nil, false, err return nil, false, err
} }
if !bytes.Equal(firstKey, keys[0]) { if !bytes.Equal(firstKey, keys[0]) {
return nil, nil, nil, false, errors.New("correct proof but invalid key") return nil, false, errors.New("correct proof but invalid key")
} }
if !bytes.Equal(val, values[0]) { if !bytes.Equal(val, values[0]) {
return nil, nil, nil, false, errors.New("correct proof but invalid data") return nil, false, errors.New("correct proof but invalid data")
} }
// Since the entire proof is a single path, we can construct a trie and a // Since the entire proof is a single path, we can construct a trie and a
// node database directly out of the inputs, no need to generate them // node database directly out of the inputs, no need to generate them
diskdb := notary.Accessed() diskdb := notary.Accessed()
tr := &Trie{ return diskdb, hasRightElement(root, firstKey), nil
Db: NewDatabase(diskdb),
root: root,
}
return diskdb, tr, notary, hasRightElement(root, firstKey), nil
} }
// Ok, in all other cases, we require two edge paths available. // Ok, in all other cases, we require two edge paths available.
// First check the validity of edge keys. // First check the validity of edge keys.
if bytes.Compare(firstKey, lastKey) >= 0 { if bytes.Compare(firstKey, lastKey) >= 0 {
return nil, nil, nil, false, errors.New("invalid edge keys") return nil, false, errors.New("invalid edge keys")
} }
// todo(rjl493456442) different length edge keys should be supported // todo(rjl493456442) different length edge keys should be supported
if len(firstKey) != len(lastKey) { if len(firstKey) != len(lastKey) {
return nil, nil, nil, false, errors.New("inconsistent edge keys") return nil, false, errors.New("inconsistent edge keys")
} }
// Convert the edge proofs to edge trie paths. Then we can // Convert the edge proofs to edge trie paths. Then we can
// have the same tree architecture with the original one. // have the same tree architecture with the original one.
// For the first edge proof, non-existent proof is allowed. // For the first edge proof, non-existent proof is allowed.
root, _, err := proofToPath(rootHash, nil, firstKey, notary, true) root, _, err := proofToPath(rootHash, nil, firstKey, notary, true)
if err != nil { if err != nil {
return nil, nil, nil, false, err return nil, false, err
} }
// Pass the root node here, the second path will be merged // Pass the root node here, the second path will be merged
// with the first one. For the last edge proof, non-existent // with the first one. For the last edge proof, non-existent
// proof is also allowed. // proof is also allowed.
root, _, err = proofToPath(rootHash, root, lastKey, notary, true) root, _, err = proofToPath(rootHash, root, lastKey, notary, true)
if err != nil { if err != nil {
return nil, nil, nil, false, err return nil, false, err
} }
// Remove all internal references. All the removed parts should // Remove all internal references. All the removed parts should
// be re-filled(or re-constructed) by the given leaves range. // be re-filled(or re-constructed) by the given leaves range.
empty, err := unsetInternal(root, firstKey, lastKey) empty, err := unsetInternal(root, firstKey, lastKey)
if err != nil { if err != nil {
return nil, nil, nil, false, err return nil, false, err
} }
// Rebuild the trie with the leaf stream, the shape of trie // Rebuild the trie with the leaf stream, the shape of trie
// should be same with the original one. // should be same with the original one.
@ -587,16 +572,16 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key
tr.TryUpdate(key, values[index]) tr.TryUpdate(key, values[index])
} }
if tr.Hash() != rootHash { if tr.Hash() != rootHash {
return nil, nil, nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash()) return nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
} }
// Proof seems valid, serialize all the nodes into the database // Proof seems valid, serialize all the nodes into the database
if _, err := tr.Commit(nil); err != nil { if _, err := tr.Commit(nil); err != nil {
return nil, nil, nil, false, err return nil, false, err
} }
if err := triedb.Commit(rootHash, false); err != nil { if err := triedb.Commit(rootHash, false); err != nil {
return nil, nil, nil, false, err return nil, false, err
} }
return diskdb, tr, notary, hasRightElement(root, keys[len(keys)-1]), nil return diskdb, hasRightElement(root, keys[len(keys)-1]), nil
} }
// get returns the child of the given Node. Return nil if the // get returns the child of the given Node. Return nil if the

View file

@ -177,7 +177,7 @@ func TestRangeProof(t *testing.T) {
keys = append(keys, entries[i].k) keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v) vals = append(vals, entries[i].v)
} }
_, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
if err != nil { if err != nil {
t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
} }
@ -228,7 +228,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
keys = append(keys, entries[i].k) keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v) vals = append(vals, entries[i].v)
} }
_, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
if err != nil { if err != nil {
t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
} }
@ -249,7 +249,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
_, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof) _, _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
if err != nil { if err != nil {
t.Fatal("Failed to verify whole rang with non-existent edges") t.Fatal("Failed to verify whole rang with non-existent edges")
} }
@ -284,7 +284,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
_, _, _, _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof) _, _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof)
if err == nil { if err == nil {
t.Fatalf("Expected to detect the error, got nil") t.Fatalf("Expected to detect the error, got nil")
} }
@ -306,7 +306,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
_, _, _, _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof) _, _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof)
if err == nil { if err == nil {
t.Fatalf("Expected to detect the error, got nil") t.Fatalf("Expected to detect the error, got nil")
} }
@ -330,7 +330,7 @@ func TestOneElementRangeProof(t *testing.T) {
if err := trie.Prove(entries[start].k, 0, proof); err != nil { if err := trie.Prove(entries[start].k, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
_, _, _, _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) _, _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -345,7 +345,7 @@ func TestOneElementRangeProof(t *testing.T) {
if err := trie.Prove(entries[start].k, 0, proof); err != nil { if err := trie.Prove(entries[start].k, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, _, _, _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) _, _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -360,7 +360,7 @@ func TestOneElementRangeProof(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, _, _, _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) _, _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -375,7 +375,7 @@ func TestOneElementRangeProof(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof) _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -394,7 +394,7 @@ func TestOneElementRangeProof(t *testing.T) {
if err := tinyTrie.Prove(last, 0, proof); err != nil { if err := tinyTrie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, _, _, _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof) _, _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -416,7 +416,7 @@ func TestAllElementsProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
_, _, _, _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil) _, _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -429,7 +429,7 @@ func TestAllElementsProof(t *testing.T) {
if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil { if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, _, _, _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof) _, _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -444,7 +444,7 @@ func TestAllElementsProof(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof) _, _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -477,7 +477,7 @@ func TestSingleSideRangeProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
_, _, _, _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof) _, _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -513,7 +513,7 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
_, _, _, _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof) _, _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -585,7 +585,7 @@ func TestBadRangeProof(t *testing.T) {
index = mrand.Intn(end - start) index = mrand.Intn(end - start)
vals[index] = nil vals[index] = nil
} }
_, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof) _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
if err == nil { if err == nil {
t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1) t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1)
} }
@ -619,7 +619,7 @@ func TestGappedRangeProof(t *testing.T) {
keys = append(keys, entries[i].k) keys = append(keys, entries[i].k)
vals = append(vals, entries[i].v) vals = append(vals, entries[i].v)
} }
_, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
if err == nil { if err == nil {
t.Fatal("expect error, got nil") t.Fatal("expect error, got nil")
} }
@ -646,7 +646,7 @@ func TestSameSideProofs(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) _, _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
if err == nil { if err == nil {
t.Fatalf("Expected error, got nil") t.Fatalf("Expected error, got nil")
} }
@ -662,7 +662,7 @@ func TestSameSideProofs(t *testing.T) {
if err := trie.Prove(last, 0, proof); err != nil { if err := trie.Prove(last, 0, proof); err != nil {
t.Fatalf("Failed to prove the last node %v", err) t.Fatalf("Failed to prove the last node %v", err)
} }
_, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof) _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
if err == nil { if err == nil {
t.Fatalf("Expected error, got nil") t.Fatalf("Expected error, got nil")
} }
@ -730,7 +730,7 @@ func TestHasRightElement(t *testing.T) {
k = append(k, entries[i].k) k = append(k, entries[i].k)
v = append(v, entries[i].v) v = append(v, entries[i].v)
} }
_, _, _, hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof) _, hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof)
if err != nil { if err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
@ -763,25 +763,19 @@ func TestEmptyRangeProof(t *testing.T) {
if err := trie.Prove(first, 0, proof); err != nil { if err := trie.Prove(first, 0, proof); err != nil {
t.Fatalf("Failed to prove the first node %v", err) t.Fatalf("Failed to prove the first node %v", err)
} }
db, tr, not, _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof) db, _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof)
if c.err && err == nil { if c.err && err == nil {
t.Fatalf("Expected error, got nil") t.Fatalf("Expected error, got nil")
} }
if !c.err && err != nil { if !c.err && err != nil {
t.Fatalf("Expected no error, got %v", err) t.Fatalf("Expected no error, got %v", err)
} }
// If no error was returned, ensure the returned trie and database contains // If no error was returned, ensure the returned database contains
// the entire proof, since there's no value // the entire proof, since there's no value
if !c.err { if !c.err {
if err := tr.Prove(first, 0, memorydb.New()); err != nil {
t.Errorf("returned trie doesn't contain original proof: %v", err)
}
if memdb := db.(*memorydb.Database); memdb.Len() != proof.Len() { if memdb := db.(*memorydb.Database); memdb.Len() != proof.Len() {
t.Errorf("database entry count mismatch: have %d, want %d", memdb.Len(), proof.Len()) t.Errorf("database entry count mismatch: have %d, want %d", memdb.Len(), proof.Len())
} }
if not == nil {
t.Errorf("missing notary")
}
} }
} }
} }
@ -800,6 +794,8 @@ func TestBloatedProof(t *testing.T) {
var vals [][]byte var vals [][]byte
proof := memorydb.New() proof := memorydb.New()
// In the 'malicious' case, we add proofs for every single item
// (but only one key/value pair used as leaf)
for i, entry := range entries { for i, entry := range entries {
trie.Prove(entry.k, 0, proof) trie.Prove(entry.k, 0, proof)
if i == 50 { if i == 50 {
@ -807,12 +803,15 @@ func TestBloatedProof(t *testing.T) {
vals = append(vals, entry.v) vals = append(vals, entry.v)
} }
} }
// For reference, we use the same function, but _only_ prove the first
// and last element
want := memorydb.New() want := memorydb.New()
trie.Prove(keys[0], 0, want) trie.Prove(keys[0], 0, want)
trie.Prove(keys[len(keys)-1], 0, want) trie.Prove(keys[len(keys)-1], 0, want)
_, _, notary, _, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof) db, _, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
if used := notary.Accessed().(*memorydb.Database); used.Len() != want.Len() { // The db should not contain anything of the bloated data
if used := db.(*memorydb.Database); used.Len() != want.Len() {
t.Fatalf("notary proof size mismatch: have %d, want %d", used.Len(), want.Len()) t.Fatalf("notary proof size mismatch: have %d, want %d", used.Len(), want.Len())
} }
} }
@ -917,13 +916,40 @@ func benchmarkVerifyRangeProof(b *testing.B, size int) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof) _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof)
if err != nil { if err != nil {
b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err) b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
} }
} }
} }
func BenchmarkVerifyRangeNoProof10(b *testing.B) { benchmarkVerifyRangeNoProof(b, 100) }
func BenchmarkVerifyRangeNoProof500(b *testing.B) { benchmarkVerifyRangeNoProof(b, 500) }
func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof(b, 1000) }
func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
trie, vals := randomTrie(size)
var entries entrySlice
for _, kv := range vals {
entries = append(entries, kv)
}
sort.Sort(entries)
var keys [][]byte
var values [][]byte
for _, entry := range entries {
keys = append(keys, entry.k)
values = append(values, entry.v)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, nil)
if err != nil {
b.Fatalf("Expected no error, got %v", err)
}
}
}
func randomTrie(n int) (*Trie, map[string]*kv) { func randomTrie(n int) (*Trie, map[string]*kv) {
trie := new(Trie) trie := new(Trie)
vals := make(map[string]*kv) vals := make(map[string]*kv)