From 4dfec7e83e4634cc8ab254b4bdbe9e692142316b Mon Sep 17 00:00:00 2001 From: rjl493456442 Date: Tue, 25 Mar 2025 21:59:44 +0800 Subject: [PATCH] trie: optimize memory allocation (#30932) This pull request removes the node copy operation to reduce memory allocation. Key Changes as below: **(a) Use `decodeNodeUnsafe` for decoding nodes retrieved from the trie node reader** In the current implementation of the MPT, once a trie node blob is retrieved, it is passed to `decodeNode` for decoding. However, `decodeNode` assumes the supplied byte slice might be mutated later, so it performs a deep copy internally before parsing the node. Given that the node reader is implemented by the path database and the hash database, both of which guarantee the immutability of the returned byte slice. By restricting the node reader interface to explicitly guarantee that the returned byte slice will not be modified, we can safely replace `decodeNode` with `decodeNodeUnsafe`. This eliminates the need for a redundant byte copy during each node resolution. **(b) Modify the trie in place** In the current implementation of the MPT, a copy of a trie node is created before any modifications are made. These modifications include: - Node resolution: Converting the value from a hash to the actual node. - Node hashing: Tagging the hash into its cache. - Node commit: Replacing the children with its hash. - Structural changes: For example, adding a new child to a fullNode or replacing a child of a shortNode. This mechanism ensures that modifications only affect the live tree, leaving all previously created copies unaffected. Unfortunately, this property leads to a huge memory allocation requirement. For example, if we want to modify the fullNode for n times, the node will be copied for n times. In this pull request, all the trie modifications are made in place. 
In order to make sure all previously created copies are unaffected, the `Copy` function now will deep-copy all the live nodes rather than the root node itself. With this change, while the `Copy` function becomes more expensive, it's totally acceptable as it's not a frequently used one. For the normal trie operations (Get, GetNode, Hash, Commit, Insert, Delete), the node copy is not required anymore. --- trie/committer.go | 39 +++------ trie/hasher.go | 84 +++++++++--------- trie/node.go | 14 ++- trie/trie.go | 68 ++++++++++----- trie/trie_test.go | 168 ++++++++++++++++++++++++++++++++++++ triedb/database/database.go | 2 + 6 files changed, 279 insertions(+), 96 deletions(-) diff --git a/trie/committer.go b/trie/committer.go index 6c4374ccfd..0939a07abb 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -57,32 +57,26 @@ func (c *committer) commit(path []byte, n node, parallel bool) node { // Commit children, then parent, and remove the dirty flag. switch cn := n.(type) { case *shortNode: - // Commit child - collapsed := cn.copy() - // If the child is fullNode, recursively commit, // otherwise it can only be hashNode or valueNode. if _, ok := cn.Val.(*fullNode); ok { - collapsed.Val = c.commit(append(path, cn.Key...), cn.Val, false) + cn.Val = c.commit(append(path, cn.Key...), cn.Val, false) } // The key needs to be copied, since we're adding it to the // modified nodeset. 
- collapsed.Key = hexToCompact(cn.Key) - hashedNode := c.store(path, collapsed) + cn.Key = hexToCompact(cn.Key) + hashedNode := c.store(path, cn) if hn, ok := hashedNode.(hashNode); ok { return hn } - return collapsed + return cn case *fullNode: - hashedKids := c.commitChildren(path, cn, parallel) - collapsed := cn.copy() - collapsed.Children = hashedKids - - hashedNode := c.store(path, collapsed) + c.commitChildren(path, cn, parallel) + hashedNode := c.store(path, cn) if hn, ok := hashedNode.(hashNode); ok { return hn } - return collapsed + return cn case hashNode: return cn default: @@ -92,11 +86,10 @@ func (c *committer) commit(path []byte, n node, parallel bool) node { } // commitChildren commits the children of the given fullnode -func (c *committer) commitChildren(path []byte, n *fullNode, parallel bool) [17]node { +func (c *committer) commitChildren(path []byte, n *fullNode, parallel bool) { var ( - wg sync.WaitGroup - nodesMu sync.Mutex - children [17]node + wg sync.WaitGroup + nodesMu sync.Mutex ) for i := 0; i < 16; i++ { child := n.Children[i] @@ -106,22 +99,21 @@ func (c *committer) commitChildren(path []byte, n *fullNode, parallel bool) [17] // If it's the hashed child, save the hash value directly. // Note: it's impossible that the child in range [0, 15] // is a valueNode. - if hn, ok := child.(hashNode); ok { - children[i] = hn + if _, ok := child.(hashNode); ok { continue } // Commit the child recursively and store the "hashed" value. // Note the returned node can be some embedded nodes, so it's // possible the type is not hashNode. 
if !parallel { - children[i] = c.commit(append(path, byte(i)), child, false) + n.Children[i] = c.commit(append(path, byte(i)), child, false) } else { wg.Add(1) go func(index int) { p := append(path, byte(index)) childSet := trienode.NewNodeSet(c.nodes.Owner) childCommitter := newCommitter(childSet, c.tracer, c.collectLeaf) - children[index] = childCommitter.commit(p, child, false) + n.Children[index] = childCommitter.commit(p, child, false) nodesMu.Lock() c.nodes.MergeSet(childSet) nodesMu.Unlock() @@ -132,11 +124,6 @@ func (c *committer) commitChildren(path []byte, n *fullNode, parallel bool) [17] if parallel { wg.Wait() } - // For the 17th child, it's possible the type is valuenode. - if n.Children[16] != nil { - children[16] = n.Children[16] - } - return children } // store hashes the node n and adds it to the modified nodeset. If leaf collection diff --git a/trie/hasher.go b/trie/hasher.go index 28f7f3d0c3..614640ae3a 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -53,62 +53,56 @@ func returnHasherToPool(h *hasher) { hasherPool.Put(h) } -// hash collapses a node down into a hash node, also returning a copy of the -// original node initialized with the computed hash to replace the original one. -func (h *hasher) hash(n node, force bool) (hashed node, cached node) { +// hash collapses a node down into a hash node. 
+func (h *hasher) hash(n node, force bool) node { // Return the cached hash if it's available if hash, _ := n.cache(); hash != nil { - return hash, n + return hash } // Trie not processed yet, walk the children switch n := n.(type) { case *shortNode: - collapsed, cached := h.hashShortNodeChildren(n) + collapsed := h.hashShortNodeChildren(n) hashed := h.shortnodeToHash(collapsed, force) - // We need to retain the possibly _not_ hashed node, in case it was too - // small to be hashed if hn, ok := hashed.(hashNode); ok { - cached.flags.hash = hn + n.flags.hash = hn } else { - cached.flags.hash = nil + n.flags.hash = nil } - return hashed, cached + return hashed case *fullNode: - collapsed, cached := h.hashFullNodeChildren(n) - hashed = h.fullnodeToHash(collapsed, force) + collapsed := h.hashFullNodeChildren(n) + hashed := h.fullnodeToHash(collapsed, force) if hn, ok := hashed.(hashNode); ok { - cached.flags.hash = hn + n.flags.hash = hn } else { - cached.flags.hash = nil + n.flags.hash = nil } - return hashed, cached + return hashed default: // Value and hash nodes don't have children, so they're left as were - return n, n + return n } } -// hashShortNodeChildren collapses the short node. The returned collapsed node -// holds a live reference to the Key, and must not be modified. -func (h *hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNode) { - // Hash the short node's child, caching the newly hashed subtree - collapsed, cached = n.copy(), n.copy() - // Previously, we did copy this one. We don't seem to need to actually - // do that, since we don't overwrite/reuse keys - // cached.Key = common.CopyBytes(n.Key) +// hashShortNodeChildren returns a copy of the supplied shortNode, with its child +// being replaced by either the hash or an embedded node if the child is small. 
+func (h *hasher) hashShortNodeChildren(n *shortNode) *shortNode { + var collapsed shortNode collapsed.Key = hexToCompact(n.Key) - // Unless the child is a valuenode or hashnode, hash it switch n.Val.(type) { case *fullNode, *shortNode: - collapsed.Val, cached.Val = h.hash(n.Val, false) + collapsed.Val = h.hash(n.Val, false) + default: + collapsed.Val = n.Val } - return collapsed, cached + return &collapsed } -func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached *fullNode) { - // Hash the full node's children, caching the newly hashed subtrees - cached = n.copy() - collapsed = n.copy() +// hashFullNodeChildren returns a copy of the supplied fullNode, with its child +// being replaced by either the hash or an embedded node if the child is small. +func (h *hasher) hashFullNodeChildren(n *fullNode) *fullNode { + var children [17]node if h.parallel { var wg sync.WaitGroup wg.Add(16) @@ -116,9 +110,9 @@ func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached go func(i int) { hasher := newHasher(false) if child := n.Children[i]; child != nil { - collapsed.Children[i], cached.Children[i] = hasher.hash(child, false) + children[i] = hasher.hash(child, false) } else { - collapsed.Children[i] = nilValueNode + children[i] = nilValueNode } returnHasherToPool(hasher) wg.Done() @@ -128,19 +122,21 @@ func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached } else { for i := 0; i < 16; i++ { if child := n.Children[i]; child != nil { - collapsed.Children[i], cached.Children[i] = h.hash(child, false) + children[i] = h.hash(child, false) } else { - collapsed.Children[i] = nilValueNode + children[i] = nilValueNode } } } - return collapsed, cached + if n.Children[16] != nil { + children[16] = n.Children[16] + } + return &fullNode{flags: nodeFlag{}, Children: children} } -// shortnodeToHash creates a hashNode from a shortNode. 
The supplied shortnode -should have hex-type Key, which will be converted (without modification) -into compact form for RLP encoding. -If the rlp data is smaller than 32 bytes, `nil` is returned. +// shortNodeToHash computes the hash of the given shortNode. The shortNode must +first be collapsed, with its key converted to compact form. If the RLP-encoded +node data is smaller than 32 bytes, the node itself is returned. func (h *hasher) shortnodeToHash(n *shortNode, force bool) node { n.encode(h.encbuf) enc := h.encodedBytes() @@ -151,8 +147,8 @@ func (h *hasher) shortnodeToHash(n *shortNode, force bool) node { return h.hashData(enc) } -// fullnodeToHash is used to create a hashNode from a fullNode, (which -may contain nil values) +// fullnodeToHash computes the hash of the given fullNode. If the RLP-encoded +node data is smaller than 32 bytes, the node itself is returned. func (h *hasher) fullnodeToHash(n *fullNode, force bool) node { n.encode(h.encbuf) enc := h.encodedBytes() @@ -203,10 +199,10 @@ func (h *hasher) hashDataTo(dst, data []byte) { func (h *hasher) proofHash(original node) (collapsed, hashed node) { switch n := original.(type) { case *shortNode: - sn, _ := h.hashShortNodeChildren(n) + sn := h.hashShortNodeChildren(n) return sn, h.shortnodeToHash(sn, false) case *fullNode: - fn, _ := h.hashFullNodeChildren(n) + fn := h.hashFullNodeChildren(n) return fn, h.fullnodeToHash(fn, false) default: // Value and hash nodes don't have children, so they're left as were diff --git a/trie/node.go b/trie/node.go index ecc2de192d..96f077ebbb 100644 --- a/trie/node.go +++ b/trie/node.go @@ -79,15 +79,19 @@ func (n *fullNode) EncodeRLP(w io.Writer) error { return eb.Flush() } -func (n *fullNode) copy() *fullNode { copy := *n; return &copy } -func (n *shortNode) copy() *shortNode { copy := *n; return &copy } - // nodeFlag contains caching-related metadata about a node. 
type nodeFlag struct { hash hashNode // cached hash of the node (may be nil) dirty bool // whether the node has changes that must be written to the database } +func (n nodeFlag) copy() nodeFlag { + return nodeFlag{ + hash: common.CopyBytes(n.hash), + dirty: n.dirty, + } +} + func (n *fullNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty } func (n *shortNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty } func (n hashNode) cache() (hashNode, bool) { return nil, true } @@ -228,7 +232,9 @@ func decodeRef(buf []byte) (node, []byte, error) { err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen) return nil, buf, err } - n, err := decodeNode(nil, buf) + // The buffer content has already been copied or is safe to use; + // no additional copy is required. + n, err := decodeNodeUnsafe(nil, buf) return n, rest, err case kind == rlp.String && len(val) == 0: // empty node diff --git a/trie/trie.go b/trie/trie.go index ae2a7b21a2..fdb4da9be4 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -29,11 +29,11 @@ import ( "github.com/ethereum/go-ethereum/triedb/database" ) -// Trie is a Merkle Patricia Trie. Use New to create a trie that sits on -// top of a database. Whenever trie performs a commit operation, the generated -// nodes will be gathered and returned in a set. Once the trie is committed, -// it's not usable anymore. Callers have to re-create the trie with new root -// based on the updated trie database. +// Trie represents a Merkle Patricia Trie. Use New to create a trie that operates +// on top of a node database. During a commit operation, the trie collects all +// modified nodes into a set for return. After committing, the trie becomes +// unusable, and callers must recreate it with the new root based on the updated +// trie database. // // Trie is not safe for concurrent use. type Trie struct { @@ -67,13 +67,13 @@ func (t *Trie) newFlag() nodeFlag { // Copy returns a copy of Trie. 
func (t *Trie) Copy() *Trie { return &Trie{ - root: t.root, + root: copyNode(t.root), owner: t.owner, committed: t.committed, + unhashed: t.unhashed, + uncommitted: t.uncommitted, reader: t.reader, tracer: t.tracer.copy(), - uncommitted: t.uncommitted, - unhashed: t.unhashed, } } @@ -169,14 +169,12 @@ func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, newnode no } value, newnode, didResolve, err = t.get(n.Val, key, pos+len(n.Key)) if err == nil && didResolve { - n = n.copy() n.Val = newnode } return value, n, didResolve, err case *fullNode: value, newnode, didResolve, err = t.get(n.Children[key[pos]], key, pos+1) if err == nil && didResolve { - n = n.copy() n.Children[key[pos]] = newnode } return value, n, didResolve, err @@ -257,7 +255,6 @@ func (t *Trie) getNode(origNode node, path []byte, pos int) (item []byte, newnod } item, newnode, resolved, err = t.getNode(n.Val, path, pos+len(n.Key)) if err == nil && resolved > 0 { - n = n.copy() n.Val = newnode } return item, n, resolved, err @@ -265,7 +262,6 @@ func (t *Trie) getNode(origNode node, path []byte, pos int) (item []byte, newnod case *fullNode: item, newnode, resolved, err = t.getNode(n.Children[path[pos]], path, pos+1) if err == nil && resolved > 0 { - n = n.copy() n.Children[path[pos]] = newnode } return item, n, resolved, err @@ -375,7 +371,6 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error if !dirty || err != nil { return false, n, err } - n = n.copy() n.flags = t.newFlag() n.Children[key[0]] = nn return true, n, nil @@ -483,7 +478,6 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { if !dirty || err != nil { return false, n, err } - n = n.copy() n.flags = t.newFlag() n.Children[key[0]] = nn @@ -576,6 +570,36 @@ func concat(s1 []byte, s2 ...byte) []byte { return r } +// copyNode deep-copies the supplied node along with its children recursively. 
+func copyNode(n node) node { + switch n := (n).(type) { + case nil: + return nil + case valueNode: + return valueNode(common.CopyBytes(n)) + + case *shortNode: + return &shortNode{ + flags: n.flags.copy(), + Key: common.CopyBytes(n.Key), + Val: copyNode(n.Val), + } + case *fullNode: + var children [17]node + for i, cn := range n.Children { + children[i] = copyNode(cn) + } + return &fullNode{ + flags: n.flags.copy(), + Children: children, + } + case hashNode: + return n + default: + panic(fmt.Sprintf("%T: unknown node type", n)) + } +} + func (t *Trie) resolve(n node, prefix []byte) (node, error) { if n, ok := n.(hashNode); ok { return t.resolveAndTrack(n, prefix) @@ -593,15 +617,16 @@ func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) { return nil, err } t.tracer.onRead(prefix, blob) - return mustDecodeNode(n, blob), nil + + // The returned node blob won't be changed afterward. No need to + // deep-copy the slice. + return decodeNodeUnsafe(n, blob) } // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. 
func (t *Trie) Hash() common.Hash { - hash, cached := t.hashRoot() - t.root = cached - return common.BytesToHash(hash.(hashNode)) + return common.BytesToHash(t.hashRoot().(hashNode)) } // Commit collects all dirty nodes in the trie and replaces them with the @@ -652,9 +677,9 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet) { } // hashRoot calculates the root hash of the given trie -func (t *Trie) hashRoot() (node, node) { +func (t *Trie) hashRoot() node { if t.root == nil { - return hashNode(types.EmptyRootHash.Bytes()), nil + return hashNode(types.EmptyRootHash.Bytes()) } // If the number of changes is below 100, we let one thread handle it h := newHasher(t.unhashed >= 100) @@ -662,8 +687,7 @@ func (t *Trie) hashRoot() (node, node) { returnHasherToPool(h) t.unhashed = 0 }() - hashed, cached := h.hash(t.root, true) - return hashed, cached + return h.hash(t.root, true) } // Witness returns a set containing all trie nodes that have been accessed. diff --git a/trie/trie_test.go b/trie/trie_test.go index 77234d9d9b..54d1b083d8 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -1330,3 +1330,171 @@ func printSet(set *trienode.NodeSet) string { } return out.String() } + +func TestTrieCopy(t *testing.T) { + testTrieCopy(t, []kv{ + {k: []byte("do"), v: []byte("verb")}, + {k: []byte("ether"), v: []byte("wookiedoo")}, + {k: []byte("horse"), v: []byte("stallion")}, + {k: []byte("shaman"), v: []byte("horse")}, + {k: []byte("doge"), v: []byte("coin")}, + {k: []byte("dog"), v: []byte("puppy")}, + }) + + var entries []kv + for i := 0; i < 256; i++ { + entries = append(entries, kv{k: testrand.Bytes(32), v: testrand.Bytes(32)}) + } + testTrieCopy(t, entries) +} + +func testTrieCopy(t *testing.T, entries []kv) { + tr := NewEmpty(nil) + for _, entry := range entries { + tr.Update(entry.k, entry.v) + } + trCpy := tr.Copy() + + if tr.Hash() != trCpy.Hash() { + t.Errorf("Hash mismatch: old %v, copy %v", tr.Hash(), trCpy.Hash()) + } + + // Check iterator 
+ it, _ := tr.NodeIterator(nil) + itCpy, _ := trCpy.NodeIterator(nil) + + for it.Next(false) { + hasNext := itCpy.Next(false) + if !hasNext { + t.Fatal("Iterator is not matched") + } + if !bytes.Equal(it.Path(), itCpy.Path()) { + t.Fatal("Iterator is not matched") + } + if it.Leaf() != itCpy.Leaf() { + t.Fatal("Iterator is not matched") + } + if it.Leaf() && !bytes.Equal(it.LeafBlob(), itCpy.LeafBlob()) { + t.Fatal("Iterator is not matched") + } + } + + // Check commit + root, nodes := tr.Commit(false) + rootCpy, nodesCpy := trCpy.Commit(false) + if root != rootCpy { + t.Fatal("root mismatch") + } + if len(nodes.Nodes) != len(nodesCpy.Nodes) { + t.Fatal("commit node mismatch") + } + for p, n := range nodes.Nodes { + nn, exists := nodesCpy.Nodes[p] + if !exists { + t.Fatalf("node not exists: %v", p) + } + if !reflect.DeepEqual(n, nn) { + t.Fatalf("node mismatch: %v", p) + } + } +} + +func TestTrieCopyOldTrie(t *testing.T) { + testTrieCopyOldTrie(t, []kv{ + {k: []byte("do"), v: []byte("verb")}, + {k: []byte("ether"), v: []byte("wookiedoo")}, + {k: []byte("horse"), v: []byte("stallion")}, + {k: []byte("shaman"), v: []byte("horse")}, + {k: []byte("doge"), v: []byte("coin")}, + {k: []byte("dog"), v: []byte("puppy")}, + }) + + var entries []kv + for i := 0; i < 256; i++ { + entries = append(entries, kv{k: testrand.Bytes(32), v: testrand.Bytes(32)}) + } + testTrieCopyOldTrie(t, entries) +} + +func testTrieCopyOldTrie(t *testing.T, entries []kv) { + tr := NewEmpty(nil) + for _, entry := range entries { + tr.Update(entry.k, entry.v) + } + hash := tr.Hash() + + trCpy := tr.Copy() + for _, val := range entries { + if rand.Intn(2) == 0 { + trCpy.Delete(val.k) + } else { + trCpy.Update(val.k, testrand.Bytes(32)) + } + } + for i := 0; i < 10; i++ { + trCpy.Update(testrand.Bytes(32), testrand.Bytes(32)) + } + trCpy.Hash() + trCpy.Commit(false) + + // Traverse the original tree, the changes made on the copy one shouldn't + // affect the old one + for _, entry := range entries { + 
d, _ := tr.Get(entry.k) + if !bytes.Equal(d, entry.v) { + t.Errorf("Unexpected data, key: %v, want: %v, got: %v", entry.k, entry.v, d) + } + } + if tr.Hash() != hash { + t.Errorf("Hash mismatch: old %v, new %v", hash, tr.Hash()) + } +} + +func TestTrieCopyNewTrie(t *testing.T) { + testTrieCopyNewTrie(t, []kv{ + {k: []byte("do"), v: []byte("verb")}, + {k: []byte("ether"), v: []byte("wookiedoo")}, + {k: []byte("horse"), v: []byte("stallion")}, + {k: []byte("shaman"), v: []byte("horse")}, + {k: []byte("doge"), v: []byte("coin")}, + {k: []byte("dog"), v: []byte("puppy")}, + }) + + var entries []kv + for i := 0; i < 256; i++ { + entries = append(entries, kv{k: testrand.Bytes(32), v: testrand.Bytes(32)}) + } + testTrieCopyNewTrie(t, entries) +} + +func testTrieCopyNewTrie(t *testing.T, entries []kv) { + tr := NewEmpty(nil) + for _, entry := range entries { + tr.Update(entry.k, entry.v) + } + trCpy := tr.Copy() + hash := trCpy.Hash() + + for _, val := range entries { + if rand.Intn(2) == 0 { + tr.Delete(val.k) + } else { + tr.Update(val.k, testrand.Bytes(32)) + } + } + for i := 0; i < 10; i++ { + tr.Update(testrand.Bytes(32), testrand.Bytes(32)) + } + + // Traverse the original tree, the changes made on the copy one shouldn't + // affect the old one + for _, entry := range entries { + d, _ := trCpy.Get(entry.k) + if !bytes.Equal(d, entry.v) { + t.Errorf("Unexpected data, key: %v, want: %v, got: %v", entry.k, entry.v, d) + } + } + if trCpy.Hash() != hash { + t.Errorf("Hash mismatch: old %v, new %v", hash, tr.Hash()) + } +} diff --git a/triedb/database/database.go b/triedb/database/database.go index cd7ec1d931..8c61ea0293 100644 --- a/triedb/database/database.go +++ b/triedb/database/database.go @@ -27,6 +27,8 @@ type NodeReader interface { // node path and the corresponding node hash. No error will be returned // if the node is not found. // + // The returned node content won't be changed after the call. 
+ // // Don't modify the returned byte slice since it's not deep-copied and // still be referenced by database. Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error)