diff --git a/trie/hasher.go b/trie/hasher.go index 16606808c9..a2a1f5b662 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -17,6 +17,8 @@ package trie import ( + "bytes" + "fmt" "sync" "github.com/ethereum/go-ethereum/crypto" @@ -54,7 +56,7 @@ func returnHasherToPool(h *hasher) { } // hash collapses a node down into a hash node. -func (h *hasher) hash(n node, force bool) node { +func (h *hasher) hash(n node, force bool) []byte { // Return the cached hash if it's available if hash, _ := n.cache(); hash != nil { return hash @@ -62,101 +64,110 @@ func (h *hasher) hash(n node, force bool) node { // Trie not processed yet, walk the children switch n := n.(type) { case *shortNode: - collapsed := h.hashShortNodeChildren(n) - hashed := h.shortnodeToHash(collapsed, force) - if hn, ok := hashed.(hashNode); ok { - n.flags.hash = hn - } else { - n.flags.hash = nil + enc := h.encodeShortNode(n) + if len(enc) < 32 && !force { + // Nodes smaller than 32 bytes are embedded directly in their parent. + // In such cases, return the raw encoded blob instead of the node hash. + // It's essential to deep-copy the node blob, as the underlying buffer + // of enc will be reused later. + buf := make([]byte, len(enc)) + copy(buf, enc) + return buf } - return hashed + hash := h.hashData(enc) + n.flags.hash = hash + return hash + case *fullNode: - collapsed := h.hashFullNodeChildren(n) - hashed := h.fullnodeToHash(collapsed, force) - if hn, ok := hashed.(hashNode); ok { - n.flags.hash = hn - } else { - n.flags.hash = nil + enc := h.encodeFullNode(n) + if len(enc) < 32 && !force { + // Nodes smaller than 32 bytes are embedded directly in their parent. + // In such cases, return the raw encoded blob instead of the node hash. + // It's essential to deep-copy the node blob, as the underlying buffer + // of enc will be reused later. + buf := make([]byte, len(enc)) + copy(buf, enc) + return buf } - return hashed - default: - // Value and hash nodes don't have children, so they're left as were + hash := h.hashData(enc) + n.flags.hash = hash + return hash + + case hashNode: + // hash nodes don't have children, so they're left as were return n - } -} -// hashShortNodeChildren returns a copy of the supplied shortNode, with its child -// being replaced by either the hash or an embedded node if the child is small. -func (h *hasher) hashShortNodeChildren(n *shortNode) *shortNode { - var collapsed shortNode - collapsed.Key = hexToCompact(n.Key) - switch n.Val.(type) { - case *fullNode, *shortNode: - collapsed.Val = h.hash(n.Val, false) default: - collapsed.Val = n.Val + panic(fmt.Errorf("unexpected node type, %T", n)) } - return &collapsed } -// hashFullNodeChildren returns a copy of the supplied fullNode, with its child -// being replaced by either the hash or an embedded node if the child is small. -func (h *hasher) hashFullNodeChildren(n *fullNode) *fullNode { - var children [17]node +// encodeShortNode encodes the provided shortNode into the bytes. Notably, the +// return slice must be deep-copied explicitly, otherwise the underlying slice +// will be reused later. +func (h *hasher) encodeShortNode(n *shortNode) []byte { + // Encode leaf node + if hasTerm(n.Key) { + var ln leafNodeEncoder + ln.Key = hexToCompact(n.Key) + ln.Val = n.Val.(valueNode) + ln.encode(h.encbuf) + return h.encodedBytes() + } + // Encode extension node + var en extNodeEncoder + en.Key = hexToCompact(n.Key) + en.Val = h.hash(n.Val, false) + en.encode(h.encbuf) + return h.encodedBytes() +} + +// fnEncoderPool is the pool for storing shared fullNode encoder to mitigate +// the significant memory allocation overhead. +var fnEncoderPool = sync.Pool{ + New: func() interface{} { + var enc fullnodeEncoder + return &enc + }, +} + +// encodeFullNode encodes the provided fullNode into the bytes. Notably, the +// return slice must be deep-copied explicitly, otherwise the underlying slice +// will be reused later. +func (h *hasher) encodeFullNode(n *fullNode) []byte { + fn := fnEncoderPool.Get().(*fullnodeEncoder) + fn.reset() + if h.parallel { var wg sync.WaitGroup for i := 0; i < 16; i++ { - if child := n.Children[i]; child != nil { - wg.Add(1) - go func(i int) { - hasher := newHasher(false) - children[i] = hasher.hash(child, false) - returnHasherToPool(hasher) - wg.Done() - }(i) - } else { - children[i] = nilValueNode + if n.Children[i] == nil { + continue } + wg.Add(1) + go func(i int) { + defer wg.Done() + + h := newHasher(false) + fn.Children[i] = h.hash(n.Children[i], false) + returnHasherToPool(h) + }(i) } wg.Wait() } else { for i := 0; i < 16; i++ { if child := n.Children[i]; child != nil { - children[i] = h.hash(child, false) - } else { - children[i] = nilValueNode + fn.Children[i] = h.hash(child, false) } } } if n.Children[16] != nil { - children[16] = n.Children[16] + fn.Children[16] = n.Children[16].(valueNode) } - return &fullNode{flags: nodeFlag{}, Children: children} -} + fn.encode(h.encbuf) + fnEncoderPool.Put(fn) -// shortNodeToHash computes the hash of the given shortNode. The shortNode must -// first be collapsed, with its key converted to compact form. If the RLP-encoded -// node data is smaller than 32 bytes, the node itself is returned. -func (h *hasher) shortnodeToHash(n *shortNode, force bool) node { - n.encode(h.encbuf) - enc := h.encodedBytes() - - if len(enc) < 32 && !force { - return n // Nodes smaller than 32 bytes are stored inside their parent - } - return h.hashData(enc) -} - -// fullnodeToHash computes the hash of the given fullNode. If the RLP-encoded -// node data is smaller than 32 bytes, the node itself is returned. -func (h *hasher) fullnodeToHash(n *fullNode, force bool) node { - n.encode(h.encbuf) - enc := h.encodedBytes() - - if len(enc) < 32 && !force { - return n // Nodes smaller than 32 bytes are stored inside their parent - } - return h.hashData(enc) + return h.encodedBytes() } // encodedBytes returns the result of the last encoding operation on h.encbuf. @@ -175,9 +186,10 @@ func (h *hasher) encodedBytes() []byte { return h.tmp } -// hashData hashes the provided data -func (h *hasher) hashData(data []byte) hashNode { - n := make(hashNode, 32) +// hashData hashes the provided data. It is safe to modify the returned slice after +// the function returns. +func (h *hasher) hashData(data []byte) []byte { + n := make([]byte, 32) h.sha.Reset() h.sha.Write(data) h.sha.Read(n) @@ -192,20 +204,17 @@ func (h *hasher) hashDataTo(dst, data []byte) { h.sha.Read(dst) } -// proofHash is used to construct trie proofs, and returns the 'collapsed' -// node (for later RLP encoding) as well as the hashed node -- unless the -// node is smaller than 32 bytes, in which case it will be returned as is. -// This method does not do anything on value- or hash-nodes. -func (h *hasher) proofHash(original node) (collapsed, hashed node) { +// proofHash is used to construct trie proofs, returning the rlp-encoded node blobs. +// Note, only resolved node (shortNode or fullNode) is expected for proofing. +// +// It is safe to modify the returned slice after the function returns. +func (h *hasher) proofHash(original node) []byte { switch n := original.(type) { case *shortNode: - sn := h.hashShortNodeChildren(n) - return sn, h.shortnodeToHash(sn, false) + return bytes.Clone(h.encodeShortNode(n)) case *fullNode: - fn := h.hashFullNodeChildren(n) - return fn, h.fullnodeToHash(fn, false) + return bytes.Clone(h.encodeFullNode(n)) default: - // Value and hash nodes don't have children, so they're left as were - return n, n + panic(fmt.Errorf("unexpected node type, %T", original)) } } diff --git a/trie/iterator.go b/trie/iterator.go index fa01611063..e6fedf2430 100644 --- a/trie/iterator.go +++ b/trie/iterator.go @@ -240,9 +240,9 @@ func (it *nodeIterator) LeafProof() [][]byte { for i, item := range it.stack[:len(it.stack)-1] { // Gather nodes that end up as hash nodes (or the root) - node, hashed := hasher.proofHash(item.node) - if _, ok := hashed.(hashNode); ok || i == 0 { - proofs = append(proofs, nodeToBytes(node)) + enc := hasher.proofHash(item.node) + if len(enc) >= 32 || i == 0 { + proofs = append(proofs, enc) } } return proofs diff --git a/trie/node.go b/trie/node.go index 96f077ebbb..74fac4fd4e 100644 --- a/trie/node.go +++ b/trie/node.go @@ -68,10 +68,6 @@ type ( } ) -// nilValueNode is used when collapsing internal trie nodes for hashing, since -// unset children need to serialize correctly. -var nilValueNode = valueNode(nil) - // EncodeRLP encodes a full node into the consensus RLP format. func (n *fullNode) EncodeRLP(w io.Writer) error { eb := rlp.NewEncoderBuffer(w) diff --git a/trie/node_enc.go b/trie/node_enc.go index c95587eeab..02b93ee6f3 100644 --- a/trie/node_enc.go +++ b/trie/node_enc.go @@ -42,18 +42,29 @@ func (n *fullNode) encode(w rlp.EncoderBuffer) { func (n *fullnodeEncoder) encode(w rlp.EncoderBuffer) { offset := w.List() - for _, c := range n.Children { - if c == nil { + for i, c := range n.Children { + if len(c) == 0 { w.Write(rlp.EmptyString) - } else if len(c) < 32 { - w.Write(c) // rawNode } else { - w.WriteBytes(c) // hashNode + // valueNode or hashNode + if i == 16 || len(c) >= 32 { + w.WriteBytes(c) + } else { + w.Write(c) // rawNode + } } } w.ListEnd(offset) } +func (n *fullnodeEncoder) reset() { + for i, c := range n.Children { + if len(c) != 0 { + n.Children[i] = n.Children[i][:0] + } + } +} + func (n *shortNode) encode(w rlp.EncoderBuffer) { offset := w.List() w.WriteBytes(n.Key) @@ -70,7 +81,7 @@ func (n *extNodeEncoder) encode(w rlp.EncoderBuffer) { w.WriteBytes(n.Key) if n.Val == nil { - w.Write(rlp.EmptyString) + w.Write(rlp.EmptyString) // theoretically impossible to happen } else if len(n.Val) < 32 { w.Write(n.Val) // rawNode } else { diff --git a/trie/proof.go b/trie/proof.go index 751d6f620f..53b7acc30c 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -22,6 +22,7 @@ import ( "fmt" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/log" ) @@ -85,16 +86,9 @@ func (t *Trie) Prove(key []byte, proofDb ethdb.KeyValueWriter) error { defer returnHasherToPool(hasher) for i, n := range nodes { - var hn node - n, hn = hasher.proofHash(n) - if hash, ok := hn.(hashNode); ok || i == 0 { - // If the node's database encoding is a hash (or is the - // root node), it becomes a proof element. - enc := nodeToBytes(n) - if !ok { - hash = hasher.hashData(enc) - } - proofDb.Put(hash, enc) + enc := hasher.proofHash(n) + if len(enc) >= 32 || i == 0 { + proofDb.Put(crypto.Keccak256(enc), enc) } } return nil diff --git a/trie/trie.go b/trie/trie.go index fdb4da9be4..222bf8b1f0 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -626,7 +626,7 @@ func (t *Trie) resolveAndTrack(n hashNode, prefix []byte) (node, error) { // Hash returns the root hash of the trie. It does not write to the // database and can be used even if the trie doesn't have one. func (t *Trie) Hash() common.Hash { - return common.BytesToHash(t.hashRoot().(hashNode)) + return common.BytesToHash(t.hashRoot()) } // Commit collects all dirty nodes in the trie and replaces them with the @@ -677,9 +677,9 @@ func (t *Trie) Commit(collectLeaf bool) (common.Hash, *trienode.NodeSet) { } // hashRoot calculates the root hash of the given trie -func (t *Trie) hashRoot() node { +func (t *Trie) hashRoot() []byte { if t.root == nil { - return hashNode(types.EmptyRootHash.Bytes()) + return types.EmptyRootHash.Bytes() } // If the number of changes is below 100, we let one thread handle it h := newHasher(t.unhashed >= 100) diff --git a/trie/trie_test.go b/trie/trie_test.go index b806ae6b0c..edd85677fe 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -863,7 +863,6 @@ func (s *spongeDb) Flush() { s.sponge.Write([]byte(key)) s.sponge.Write([]byte(s.values[key])) } - fmt.Println(len(s.keys)) } // spongeBatch is a dummy batch which immediately writes to the underlying spongedb