1
0
Fork 0
forked from forks/go-ethereum

trie: optimize memory allocation (#30932)

This pull request removes the node copy operation to reduce memory
allocation. Key Changes as below:

**(a) Use `decodeNodeUnsafe` for decoding nodes retrieved from the trie
node reader**

In the current implementation of the MPT, once a trie node blob is
retrieved, it is passed to `decodeNode` for decoding. However,
`decodeNode` assumes the supplied byte slice might be mutated later, so
it performs a deep copy internally before parsing the node.

Given that the node reader is implemented by the path database and the
hash database, both of which guarantee the immutability of the returned
byte slice. By restricting the node reader interface to explicitly
guarantee that the returned byte slice will not be modified, we can
safely replace `decodeNode` with `decodeNodeUnsafe`. This eliminates the
need for a redundant byte copy during each node resolution.

**(b) Modify the trie in place**

In the current implementation of the MPT, a copy of a trie node is
created before any modifications are made. These modifications include:
- Node resolution: Converting the value from a hash to the actual node.
- Node hashing: Tagging the hash into its cache.
- Node commit: Replacing the children with its hash.
- Structural changes: For example, adding a new child to a fullNode or
replacing a child of a shortNode.

This mechanism ensures that modifications only affect the live tree,
leaving all previously created copies unaffected.

Unfortunately, this property leads to a huge memory allocation
requirement. For example, if we want to modify the fullNode for n times,
the node will be copied for n times.

In this pull request, all the trie modifications are made in place. In
order to make sure all previously created copies are unaffected, the
`Copy` function now will deep-copy all the live nodes rather than the
root node itself.

With this change, while the `Copy` function becomes more expensive, it's
totally acceptable as it's not a frequently used one. For the normal
trie operations (Get, GetNode, Hash, Commit, Insert, Delete), the node
copy is not required anymore.
This commit is contained in:
rjl493456442 2025-03-25 21:59:44 +08:00 committed by GitHub
parent 4ff5093df1
commit 4dfec7e83e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 279 additions and 96 deletions

View file

@ -57,32 +57,26 @@ func (c *committer) commit(path []byte, n node, parallel bool) node {
// Commit children, then parent, and remove the dirty flag.
switch cn := n.(type) {
case *shortNode:
// Commit child
collapsed := cn.copy()
// If the child is fullNode, recursively commit,
// otherwise it can only be hashNode or valueNode.
if _, ok := cn.Val.(*fullNode); ok {
collapsed.Val = c.commit(append(path, cn.Key...), cn.Val, false)
cn.Val = c.commit(append(path, cn.Key...), cn.Val, false)
}
// The key needs to be copied, since we're adding it to the
// modified nodeset.
collapsed.Key = hexToCompact(cn.Key)
hashedNode := c.store(path, collapsed)
cn.Key = hexToCompact(cn.Key)
hashedNode := c.store(path, cn)
if hn, ok := hashedNode.(hashNode); ok {
return hn
}
return collapsed
return cn
case *fullNode:
hashedKids := c.commitChildren(path, cn, parallel)
collapsed := cn.copy()
collapsed.Children = hashedKids
hashedNode := c.store(path, collapsed)
c.commitChildren(path, cn, parallel)
hashedNode := c.store(path, cn)
if hn, ok := hashedNode.(hashNode); ok {
return hn
}
return collapsed
return cn
case hashNode:
return cn
default:
@ -92,11 +86,10 @@ func (c *committer) commit(path []byte, n node, parallel bool) node {
}
// commitChildren commits the children of the given fullnode
func (c *committer) commitChildren(path []byte, n *fullNode, parallel bool) [17]node {
func (c *committer) commitChildren(path []byte, n *fullNode, parallel bool) {
var (
wg sync.WaitGroup
nodesMu sync.Mutex
children [17]node
wg sync.WaitGroup
nodesMu sync.Mutex
)
for i := 0; i < 16; i++ {
child := n.Children[i]
@ -106,22 +99,21 @@ func (c *committer) commitChildren(path []byte, n *fullNode, parallel bool) [17]
// If it's the hashed child, save the hash value directly.
// Note: it's impossible that the child in range [0, 15]
// is a valueNode.
if hn, ok := child.(hashNode); ok {
children[i] = hn
if _, ok := child.(hashNode); ok {
continue
}
// Commit the child recursively and store the "hashed" value.
// Note the returned node can be some embedded nodes, so it's
// possible the type is not hashNode.
if !parallel {
children[i] = c.commit(append(path, byte(i)), child, false)
n.Children[i] = c.commit(append(path, byte(i)), child, false)
} else {
wg.Add(1)
go func(index int) {
p := append(path, byte(index))
childSet := trienode.NewNodeSet(c.nodes.Owner)
childCommitter := newCommitter(childSet, c.tracer, c.collectLeaf)
children[index] = childCommitter.commit(p, child, false)
n.Children[index] = childCommitter.commit(p, child, false)
nodesMu.Lock()
c.nodes.MergeSet(childSet)
nodesMu.Unlock()
@ -132,11 +124,6 @@ func (c *committer) commitChildren(path []byte, n *fullNode, parallel bool) [17]
if parallel {
wg.Wait()
}
// For the 17th child, it's possible the type is valuenode.
if n.Children[16] != nil {
children[16] = n.Children[16]
}
return children
}
// store hashes the node n and adds it to the modified nodeset. If leaf collection

View file

@ -53,62 +53,56 @@ func returnHasherToPool(h *hasher) {
hasherPool.Put(h)
}
// hash collapses a node down into a hash node, also returning a copy of the
// original node initialized with the computed hash to replace the original one.
func (h *hasher) hash(n node, force bool) (hashed node, cached node) {
// hash collapses a node down into a hash node.
func (h *hasher) hash(n node, force bool) node {
// Return the cached hash if it's available
if hash, _ := n.cache(); hash != nil {
return hash, n
return hash
}
// Trie not processed yet, walk the children
switch n := n.(type) {
case *shortNode:
collapsed, cached := h.hashShortNodeChildren(n)
collapsed := h.hashShortNodeChildren(n)
hashed := h.shortnodeToHash(collapsed, force)
// We need to retain the possibly _not_ hashed node, in case it was too
// small to be hashed
if hn, ok := hashed.(hashNode); ok {
cached.flags.hash = hn
n.flags.hash = hn
} else {
cached.flags.hash = nil
n.flags.hash = nil
}
return hashed, cached
return hashed
case *fullNode:
collapsed, cached := h.hashFullNodeChildren(n)
hashed = h.fullnodeToHash(collapsed, force)
collapsed := h.hashFullNodeChildren(n)
hashed := h.fullnodeToHash(collapsed, force)
if hn, ok := hashed.(hashNode); ok {
cached.flags.hash = hn
n.flags.hash = hn
} else {
cached.flags.hash = nil
n.flags.hash = nil
}
return hashed, cached
return hashed
default:
// Value and hash nodes don't have children, so they're left as were
return n, n
return n
}
}
// hashShortNodeChildren collapses the short node. The returned collapsed node
// holds a live reference to the Key, and must not be modified.
func (h *hasher) hashShortNodeChildren(n *shortNode) (collapsed, cached *shortNode) {
// Hash the short node's child, caching the newly hashed subtree
collapsed, cached = n.copy(), n.copy()
// Previously, we did copy this one. We don't seem to need to actually
// do that, since we don't overwrite/reuse keys
// cached.Key = common.CopyBytes(n.Key)
// hashShortNodeChildren returns a copy of the supplied shortNode, with its child
// being replaced by either the hash or an embedded node if the child is small.
func (h *hasher) hashShortNodeChildren(n *shortNode) *shortNode {
var collapsed shortNode
collapsed.Key = hexToCompact(n.Key)
// Unless the child is a valuenode or hashnode, hash it
switch n.Val.(type) {
case *fullNode, *shortNode:
collapsed.Val, cached.Val = h.hash(n.Val, false)
collapsed.Val = h.hash(n.Val, false)
default:
collapsed.Val = n.Val
}
return collapsed, cached
return &collapsed
}
func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached *fullNode) {
// Hash the full node's children, caching the newly hashed subtrees
cached = n.copy()
collapsed = n.copy()
// hashFullNodeChildren returns a copy of the supplied fullNode, with its child
// being replaced by either the hash or an embedded node if the child is small.
func (h *hasher) hashFullNodeChildren(n *fullNode) *fullNode {
var children [17]node
if h.parallel {
var wg sync.WaitGroup
wg.Add(16)
@ -116,9 +110,9 @@ func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached
go func(i int) {
hasher := newHasher(false)
if child := n.Children[i]; child != nil {
collapsed.Children[i], cached.Children[i] = hasher.hash(child, false)
children[i] = hasher.hash(child, false)
} else {
collapsed.Children[i] = nilValueNode
children[i] = nilValueNode
}
returnHasherToPool(hasher)
wg.Done()
@ -128,19 +122,21 @@ func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached
} else {
for i := 0; i < 16; i++ {
if child := n.Children[i]; child != nil {
collapsed.Children[i], cached.Children[i] = h.hash(child, false)
children[i] = h.hash(child, false)
} else {
collapsed.Children[i] = nilValueNode
children[i] = nilValueNode
}
}
}
return collapsed, cached
if n.Children[16] != nil {
children[16] = n.Children[16]
}
return &fullNode{flags: nodeFlag{}, Children: children}
}
// shortnodeToHash creates a hashNode from a shortNode. The supplied shortnode
// should have hex-type Key, which will be converted (without modification)
// into compact form for RLP encoding.
// If the rlp data is smaller than 32 bytes, `nil` is returned.
// shortNodeToHash computes the hash of the given shortNode. The shortNode must
// first be collapsed, with its key converted to compact form. If the RLP-encoded
// node data is smaller than 32 bytes, the node itself is returned.
func (h *hasher) shortnodeToHash(n *shortNode, force bool) node {
n.encode(h.encbuf)
enc := h.encodedBytes()
@ -151,8 +147,8 @@ func (h *hasher) shortnodeToHash(n *shortNode, force bool) node {
return h.hashData(enc)
}
// fullnodeToHash is used to create a hashNode from a fullNode, (which
// may contain nil values)
// fullnodeToHash computes the hash of the given fullNode. If the RLP-encoded
// node data is smaller than 32 bytes, the node itself is returned.
func (h *hasher) fullnodeToHash(n *fullNode, force bool) node {
n.encode(h.encbuf)
enc := h.encodedBytes()
@ -203,10 +199,10 @@ func (h *hasher) hashDataTo(dst, data []byte) {
func (h *hasher) proofHash(original node) (collapsed, hashed node) {
switch n := original.(type) {
case *shortNode:
sn, _ := h.hashShortNodeChildren(n)
sn := h.hashShortNodeChildren(n)
return sn, h.shortnodeToHash(sn, false)
case *fullNode:
fn, _ := h.hashFullNodeChildren(n)
fn := h.hashFullNodeChildren(n)
return fn, h.fullnodeToHash(fn, false)
default:
// Value and hash nodes don't have children, so they're left as were

View file

@ -79,15 +79,19 @@ func (n *fullNode) EncodeRLP(w io.Writer) error {
return eb.Flush()
}
func (n *fullNode) copy() *fullNode { copy := *n; return &copy }
func (n *shortNode) copy() *shortNode { copy := *n; return &copy }
// nodeFlag contains caching-related metadata about a node.
type nodeFlag struct {
hash hashNode // cached hash of the node (may be nil)
dirty bool // whether the node has changes that must be written to the database
}
func (n nodeFlag) copy() nodeFlag {
return nodeFlag{
hash: common.CopyBytes(n.hash),
dirty: n.dirty,
}
}
func (n *fullNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
func (n *shortNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
func (n hashNode) cache() (hashNode, bool) { return nil, true }
@ -228,7 +232,9 @@ func decodeRef(buf []byte) (node, []byte, error) {
err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
return nil, buf, err
}
n, err := decodeNode(nil, buf)
// The buffer content has already been copied or is safe to use;
// no additional copy is required.
n, err := decodeNodeUnsafe(nil, buf)
return n, rest, err
case kind == rlp.String && len(val) == 0:
// empty node

View file

@ -29,11 +29,11 @@ import (
"github.com/ethereum/go-ethereum/triedb/database"
)
// Trie is a Merkle Patricia Trie. Use New to create a trie that sits on
// top of a database. Whenever trie performs a commit operation, the generated
// nodes will be gathered and returned in a set. Once the trie is committed,
// it's not usable anymore. Callers have to re-create the trie with new root
// based on the updated trie database.
// Trie represents a Merkle Patricia Trie. Use New to create a trie that operates
// on top of a node database. During a commit operation, the trie collects all
// modified nodes into a set for return. After committing, the trie becomes
// unusable, and callers must recreate it with the new root based on the updated
// trie database.
//
// Trie is not safe for concurrent use.
type Trie struct {
@ -67,13 +67,13 @@ func (t *Trie) newFlag() nodeFlag {
// Copy returns a copy of Trie.
func (t *Trie) Copy() *Trie {
return &Trie{
root: t.root,
root: copyNode(t.root),
owner: t.owner,
committed: t.committed,
unhashed: t.unhashed,
uncommitted: t.uncommitted,
reader: t.reader,
tracer: t.tracer.copy(),
uncommitted: t.uncommitted,
unhashed: t.unhashed,
}
}
@ -169,14 +169,12 @@ func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, newnode no
}
value, newnode, didResolve, err = t.get(n.Val, key, pos+len(n.Key))
if err == nil && didResolve {
n = n.copy()
n.Val = newnode
}
return value, n, didResolve, err
case *fullNode:
value, newnode, didResolve, err = t.get(n.Children[key[pos]], key, pos+1)
if err == nil && didResolve {
n = n.copy()
n.Children[key[pos]] = newnode
}
return value, n, didResolve, err
@ -257,7 +255,6 @@ func (t *Trie) getNode(origNode node, path []byte, pos int) (item []byte, newnod
}
item, newnode, resolved, err = t.getNode(n.Val, path, pos+len(n.Key))
if err == nil && resolved > 0 {
n = n.copy()
n.Val = newnode
}
return item, n, resolved, err
@ -265,7 +262,6 @@ func (t *Trie) getNode(origNode node, path []byte, pos int) (item []byte, newnod
case *fullNode:
item, newnode, resolved, err = t.getNode(n.Children[path[pos]], path, pos+1)
if err == nil && resolved > 0 {
n = n.copy()
n.Children[path[pos]] = newnode
}
return item, n, resolved, err
@ -375,7 +371,6 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
if !dirty || err != nil {
return false, n, err
}
n = n.copy()
n.flags = t.newFlag()
n.Children[key[0]] = nn
return true, n, nil
@ -483,7 +478,6 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
if !dirty || err != nil {
return false, n, err
}
n = n.copy()
n.flags = t.newFlag()
n.Children[key[0]] = nn
@ -576,6 +570,36 @@ func concat(s1 []byte, s2 ...byte) []byte {
return r
}
// copyNode deep-copies the supplied node along with its children recursively.
func copyNode(n node) node {
switch n := (n).(type) {
case nil:
return nil
case valueNode:
return valueNode(common.CopyBytes(n))
case *shortNode:
return &shortNode{
flags: n.flags.copy(),
Key: common.CopyBytes(n.Key),
Val: copyNode(n.Val),
}
case *fullNode:
var children [17]node
for i, cn := range n.Children {
children[i] = copyNode(cn)
}
return &fullNode{
flags: n.flags.copy(),
Children: children,
}
case hashNode:
return n
default:
panic(fmt.Sprintf("%T: unknown node type", n))
}
}
func (t *Trie) resolve(n node, prefix []byte) (node, error) {
if n, ok := n.(hashNode); ok {
return t.resolveAndTrack(n, prefix)
@ -593,15 +617,16 @@ func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) {
return nil, err
}
t.tracer.onRead(prefix, blob)
return mustDecodeNode(n, blob), nil
// The returned node blob won't be changed afterward. No need to
// deep-copy the slice.
return decodeNodeUnsafe(n, blob)
}
// Hash returns the root hash of the trie. It does not write to the
// database and can be used even if the trie doesn't have one.
func (t *Trie) Hash() common.Hash {
hash, cached := t.hashRoot()
t.root = cached
return common.BytesToHash(hash.(hashNode))
return common.BytesToHash(t.hashRoot().(hashNode))
}
// Commit collects all dirty nodes in the trie and replaces them with the
@ -652,9 +677,9 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet) {
}
// hashRoot calculates the root hash of the given trie
func (t *Trie) hashRoot() (node, node) {
func (t *Trie) hashRoot() node {
if t.root == nil {
return hashNode(types.EmptyRootHash.Bytes()), nil
return hashNode(types.EmptyRootHash.Bytes())
}
// If the number of changes is below 100, we let one thread handle it
h := newHasher(t.unhashed >= 100)
@ -662,8 +687,7 @@ func (t *Trie) hashRoot() (node, node) {
returnHasherToPool(h)
t.unhashed = 0
}()
hashed, cached := h.hash(t.root, true)
return hashed, cached
return h.hash(t.root, true)
}
// Witness returns a set containing all trie nodes that have been accessed.

View file

@ -1330,3 +1330,171 @@ func printSet(set *trienode.NodeSet) string {
}
return out.String()
}
func TestTrieCopy(t *testing.T) {
testTrieCopy(t, []kv{
{k: []byte("do"), v: []byte("verb")},
{k: []byte("ether"), v: []byte("wookiedoo")},
{k: []byte("horse"), v: []byte("stallion")},
{k: []byte("shaman"), v: []byte("horse")},
{k: []byte("doge"), v: []byte("coin")},
{k: []byte("dog"), v: []byte("puppy")},
})
var entries []kv
for i := 0; i < 256; i++ {
entries = append(entries, kv{k: testrand.Bytes(32), v: testrand.Bytes(32)})
}
testTrieCopy(t, entries)
}
func testTrieCopy(t *testing.T, entries []kv) {
tr := NewEmpty(nil)
for _, entry := range entries {
tr.Update(entry.k, entry.v)
}
trCpy := tr.Copy()
if tr.Hash() != trCpy.Hash() {
t.Errorf("Hash mismatch: old %v, copy %v", tr.Hash(), trCpy.Hash())
}
// Check iterator
it, _ := tr.NodeIterator(nil)
itCpy, _ := trCpy.NodeIterator(nil)
for it.Next(false) {
hasNext := itCpy.Next(false)
if !hasNext {
t.Fatal("Iterator is not matched")
}
if !bytes.Equal(it.Path(), itCpy.Path()) {
t.Fatal("Iterator is not matched")
}
if it.Leaf() != itCpy.Leaf() {
t.Fatal("Iterator is not matched")
}
if it.Leaf() && !bytes.Equal(it.LeafBlob(), itCpy.LeafBlob()) {
t.Fatal("Iterator is not matched")
}
}
// Check commit
root, nodes := tr.Commit(false)
rootCpy, nodesCpy := trCpy.Commit(false)
if root != rootCpy {
t.Fatal("root mismatch")
}
if len(nodes.Nodes) != len(nodesCpy.Nodes) {
t.Fatal("commit node mismatch")
}
for p, n := range nodes.Nodes {
nn, exists := nodesCpy.Nodes[p]
if !exists {
t.Fatalf("node not exists: %v", p)
}
if !reflect.DeepEqual(n, nn) {
t.Fatalf("node mismatch: %v", p)
}
}
}
func TestTrieCopyOldTrie(t *testing.T) {
testTrieCopyOldTrie(t, []kv{
{k: []byte("do"), v: []byte("verb")},
{k: []byte("ether"), v: []byte("wookiedoo")},
{k: []byte("horse"), v: []byte("stallion")},
{k: []byte("shaman"), v: []byte("horse")},
{k: []byte("doge"), v: []byte("coin")},
{k: []byte("dog"), v: []byte("puppy")},
})
var entries []kv
for i := 0; i < 256; i++ {
entries = append(entries, kv{k: testrand.Bytes(32), v: testrand.Bytes(32)})
}
testTrieCopyOldTrie(t, entries)
}
func testTrieCopyOldTrie(t *testing.T, entries []kv) {
tr := NewEmpty(nil)
for _, entry := range entries {
tr.Update(entry.k, entry.v)
}
hash := tr.Hash()
trCpy := tr.Copy()
for _, val := range entries {
if rand.Intn(2) == 0 {
trCpy.Delete(val.k)
} else {
trCpy.Update(val.k, testrand.Bytes(32))
}
}
for i := 0; i < 10; i++ {
trCpy.Update(testrand.Bytes(32), testrand.Bytes(32))
}
trCpy.Hash()
trCpy.Commit(false)
// Traverse the original tree, the changes made on the copy one shouldn't
// affect the old one
for _, entry := range entries {
d, _ := tr.Get(entry.k)
if !bytes.Equal(d, entry.v) {
t.Errorf("Unexpected data, key: %v, want: %v, got: %v", entry.k, entry.v, d)
}
}
if tr.Hash() != hash {
t.Errorf("Hash mismatch: old %v, new %v", hash, tr.Hash())
}
}
func TestTrieCopyNewTrie(t *testing.T) {
testTrieCopyNewTrie(t, []kv{
{k: []byte("do"), v: []byte("verb")},
{k: []byte("ether"), v: []byte("wookiedoo")},
{k: []byte("horse"), v: []byte("stallion")},
{k: []byte("shaman"), v: []byte("horse")},
{k: []byte("doge"), v: []byte("coin")},
{k: []byte("dog"), v: []byte("puppy")},
})
var entries []kv
for i := 0; i < 256; i++ {
entries = append(entries, kv{k: testrand.Bytes(32), v: testrand.Bytes(32)})
}
testTrieCopyNewTrie(t, entries)
}
func testTrieCopyNewTrie(t *testing.T, entries []kv) {
tr := NewEmpty(nil)
for _, entry := range entries {
tr.Update(entry.k, entry.v)
}
trCpy := tr.Copy()
hash := trCpy.Hash()
for _, val := range entries {
if rand.Intn(2) == 0 {
tr.Delete(val.k)
} else {
tr.Update(val.k, testrand.Bytes(32))
}
}
for i := 0; i < 10; i++ {
tr.Update(testrand.Bytes(32), testrand.Bytes(32))
}
// Traverse the original tree, the changes made on the copy one shouldn't
// affect the old one
for _, entry := range entries {
d, _ := trCpy.Get(entry.k)
if !bytes.Equal(d, entry.v) {
t.Errorf("Unexpected data, key: %v, want: %v, got: %v", entry.k, entry.v, d)
}
}
if trCpy.Hash() != hash {
t.Errorf("Hash mismatch: old %v, new %v", hash, tr.Hash())
}
}

View file

@ -27,6 +27,8 @@ type NodeReader interface {
// node path and the corresponding node hash. No error will be returned
// if the node is not found.
//
// The returned node content won't be changed after the call.
//
// Don't modify the returned byte slice since it's not deep-copied and
// still be referenced by database.
Node(owner common.Hash, path []byte, hash common.Hash) ([]byte, error)