From ac85a6f25449a7b96505637e18c988b3142d47bf Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Fri, 13 Feb 2026 23:53:42 +0100 Subject: [PATCH] rlp: add back Iterator.Count, with fixes (#33841) I removed `Iterator.Count` in #33840, because it appeared to be unused and did not provide the documented invariant: the returned count should always be an upper bound on the number of iterations allowed by `Next`. In order to make `Count` work, the semantics of `CountValues` has to change to return the number of items up and including the invalid one. I have reviewed all callsites of `CountValues` to assess if changing this is safe. There aren't that many, and the only call that doesn't check the error and return is in the trie node parser, `trie.decodeNodeUnsafe`. There, we distinguish the node type based on the number of items, and it previously returned an error for item count zero. In order to avoid any potential issue that could result from this change, I'm adding an error check in that function, though it isn't necessary. --- rlp/iterator.go | 9 ++++++ rlp/iterator_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++++ rlp/raw.go | 2 +- rlp/raw_test.go | 6 ++-- trie/node.go | 9 ++++-- 5 files changed, 85 insertions(+), 7 deletions(-) diff --git a/rlp/iterator.go b/rlp/iterator.go index d6151eabd7..9e41cec947 100644 --- a/rlp/iterator.go +++ b/rlp/iterator.go @@ -69,6 +69,15 @@ func (it *Iterator) Value() []byte { return it.next } +// Count returns the remaining number of items. +// Note this is O(n) and the result may be incorrect if the list data is invalid. +// The returned count is always an upper bound on the remaining items +// that will be visited by the iterator. +func (it *Iterator) Count() int { + count, _ := CountValues(it.data) + return count +} + // Offset returns the offset of the current value into the list data. func (it *Iterator) Offset() int { return it.offset - len(it.next) diff --git a/rlp/iterator_test.go b/rlp/iterator_test.go index 8ea26dad1b..275d4371c7 100644 --- a/rlp/iterator_test.go +++ b/rlp/iterator_test.go @@ -17,6 +17,7 @@ package rlp import ( + "io" "testing" "github.com/ethereum/go-ethereum/common/hexutil" @@ -54,6 +55,9 @@ func TestIterator(t *testing.T) { if err != nil { t.Fatal(err) } + if c := txit.Count(); c != 2 { + t.Fatal("wrong Count:", c) + } var i = 0 for txit.Next() { if txit.err != nil { @@ -65,3 +69,65 @@ func TestIterator(t *testing.T) { t.Errorf("count wrong, expected %d got %d", i, exp) } } + +func TestIteratorErrors(t *testing.T) { + tests := []struct { + input []byte + wantCount int // expected Count before iterating + wantErr error + }{ + // Second item string header claims 3 bytes content, but only 2 remain. + {unhex("C4 01 83AABB"), 2, ErrValueTooLarge}, + // Second item truncated: B9 requires 2 size bytes, none available. + {unhex("C2 01 B9"), 2, io.ErrUnexpectedEOF}, + // 0x05 should be encoded directly, not as 81 05. + {unhex("C3 01 8105"), 2, ErrCanonSize}, + // Long-form string header B8 used for 1-byte content (< 56). + {unhex("C4 01 B801AA"), 2, ErrCanonSize}, + // Long-form list header F8 used for 1-byte content (< 56). + {unhex("C4 01 F80101"), 2, ErrCanonSize}, + } + for _, tt := range tests { + it, err := NewListIterator(tt.input) + if err != nil { + t.Fatal("NewListIterator error:", err) + } + if c := it.Count(); c != tt.wantCount { + t.Fatalf("%x: Count = %d, want %d", tt.input, c, tt.wantCount) + } + n := 0 + for it.Next() { + if it.Err() != nil { + break + } + n++ + } + if wantN := tt.wantCount - 1; n != wantN { + t.Fatalf("%x: got %d valid items, want %d", tt.input, n, wantN) + } + if it.Err() != tt.wantErr { + t.Fatalf("%x: got error %v, want %v", tt.input, it.Err(), tt.wantErr) + } + if it.Next() { + t.Fatalf("%x: Next returned true after error", tt.input) + } + } +} + +func FuzzIteratorCount(f *testing.F) { + examples := [][]byte{unhex("010203"), unhex("018142"), unhex("01830202")} + for _, e := range examples { + f.Add(e) + } + f.Fuzz(func(t *testing.T, in []byte) { + it := newIterator(in, 0) + count := it.Count() + i := 0 + for it.Next() { + i++ + } + if i != count { + t.Fatalf("%x: count %d not equal to %d iterations", in, count, i) + } + }) +} diff --git a/rlp/raw.go b/rlp/raw.go index 69946f6d22..08ec667158 100644 --- a/rlp/raw.go +++ b/rlp/raw.go @@ -285,7 +285,7 @@ func CountValues(b []byte) (int, error) { for ; len(b) > 0; i++ { _, tagsize, size, err := readKind(b) if err != nil { - return 0, err + return i + 1, err } b = b[tagsize+size:] } diff --git a/rlp/raw_test.go b/rlp/raw_test.go index 49ac3dcc62..112c5d7897 100644 --- a/rlp/raw_test.go +++ b/rlp/raw_test.go @@ -288,9 +288,9 @@ func TestCountValues(t *testing.T) { {"820101 820202 8403030303 04", 4, nil}, // size errors - {"8142", 0, ErrCanonSize}, - {"01 01 8142", 0, ErrCanonSize}, - {"02 84020202", 0, ErrValueTooLarge}, + {"8142", 1, ErrCanonSize}, + {"01 01 8142", 3, ErrCanonSize}, + {"02 84020202", 2, ErrValueTooLarge}, { input: "A12000BF49F440A1CD0527E4D06E2765654C0F56452257516D793A9B8D604DCFDF2AB853F851808D10000000000000000000000000A056E81F171BCC55A6FF8345E692C0F86E5B48E01B996CADC001622FB5E363B421A0C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470", diff --git a/trie/node.go b/trie/node.go index bff2d0516a..7022116048 100644 --- a/trie/node.go +++ b/trie/node.go @@ -161,11 +161,14 @@ func decodeNodeUnsafe(hash, buf []byte) (node, error) { if err != nil { return nil, fmt.Errorf("decode error: %v", err) } - switch c, _ := rlp.CountValues(elems); c { - case 2: + c, err := rlp.CountValues(elems) + switch { + case err != nil: + return nil, fmt.Errorf("invalid node list: %v", err) + case c == 2: n, err := decodeShort(hash, elems) return n, wrapError(err, "short") - case 17: + case c == 17: n, err := decodeFull(hash, elems) return n, wrapError(err, "full") default: