rlp: add back Iterator.Count, with fixes (#33841)

I removed `Iterator.Count` in #33840, because it appeared to be unused
and did not provide the documented invariant: the returned count should
always be an upper bound on the number of iterations allowed by `Next`.

In order to make `Count` work, the semantics of `CountValues` has to
change to return the number of items up to and including the invalid one. I
have reviewed all callsites of `CountValues` to assess if changing this
is safe. There aren't that many, and the only call that doesn't check
the error and return is in the trie node parser,
`trie.decodeNodeUnsafe`. There, we distinguish the node type based on
the number of items, and it previously returned an error for item count
zero. In order to avoid any potential issue that could result from this
change, I'm adding an error check in that function, though it isn't
necessary.
This commit is contained in:
Felix Lange 2026-02-13 23:53:42 +01:00 committed by GitHub
parent 4f38a76438
commit ac85a6f254
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 85 additions and 7 deletions

View file

@ -69,6 +69,15 @@ func (it *Iterator) Value() []byte {
return it.next return it.next
} }
// Count reports how many items are left in the iterator.
//
// The remaining list data is scanned in full, so the call is O(n). The error
// reported by CountValues is discarded: when the data is invalid, the result
// covers the items up to and including the first invalid one and therefore
// may not match the number of valid items. It is always an upper bound on the
// number of items the iterator will still visit.
func (it *Iterator) Count() int {
	remaining, _ := CountValues(it.data)
	return remaining
}
// Offset returns the offset of the current value into the list data. // Offset returns the offset of the current value into the list data.
func (it *Iterator) Offset() int { func (it *Iterator) Offset() int {
return it.offset - len(it.next) return it.offset - len(it.next)

View file

@ -17,6 +17,7 @@
package rlp package rlp
import ( import (
"io"
"testing" "testing"
"github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/hexutil"
@ -54,6 +55,9 @@ func TestIterator(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if c := txit.Count(); c != 2 {
t.Fatal("wrong Count:", c)
}
var i = 0 var i = 0
for txit.Next() { for txit.Next() {
if txit.err != nil { if txit.err != nil {
@ -65,3 +69,65 @@ func TestIterator(t *testing.T) {
t.Errorf("count wrong, expected %d got %d", i, exp) t.Errorf("count wrong, expected %d got %d", i, exp)
} }
} }
// TestIteratorErrors feeds the iterator lists whose second item is malformed.
// For each input it checks the Count reported before iterating, the number of
// items yielded before an error shows up, the error itself, and that Next
// keeps returning false once the error has been hit.
func TestIteratorErrors(t *testing.T) {
	type testCase struct {
		input     []byte
		wantCount int   // expected Count before iterating
		wantErr   error // expected Err after iteration stops
	}
	cases := []testCase{
		// Second item string header claims 3 bytes content, but only 2 remain.
		{input: unhex("C4 01 83AABB"), wantCount: 2, wantErr: ErrValueTooLarge},
		// Second item truncated: B9 requires 2 size bytes, none available.
		{input: unhex("C2 01 B9"), wantCount: 2, wantErr: io.ErrUnexpectedEOF},
		// 0x05 should be encoded directly, not as 81 05.
		{input: unhex("C3 01 8105"), wantCount: 2, wantErr: ErrCanonSize},
		// Long-form string header B8 used for 1-byte content (< 56).
		{input: unhex("C4 01 B801AA"), wantCount: 2, wantErr: ErrCanonSize},
		// Long-form list header F8 used for 1-byte content (< 56).
		{input: unhex("C4 01 F80101"), wantCount: 2, wantErr: ErrCanonSize},
	}
	for _, tc := range cases {
		iter, err := NewListIterator(tc.input)
		if err != nil {
			t.Fatal("NewListIterator error:", err)
		}
		if got := iter.Count(); got != tc.wantCount {
			t.Fatalf("%x: Count = %d, want %d", tc.input, got, tc.wantCount)
		}

		// Tally the items yielded before the first error is reported.
		valid := 0
		for iter.Next() && iter.Err() == nil {
			valid++
		}
		// Count includes the invalid item, so one fewer item is valid.
		wantValid := tc.wantCount - 1
		if valid != wantValid {
			t.Fatalf("%x: got %d valid items, want %d", tc.input, valid, wantValid)
		}
		if gotErr := iter.Err(); gotErr != tc.wantErr {
			t.Fatalf("%x: got error %v, want %v", tc.input, iter.Err(), tc.wantErr)
		}
		// The iterator must stay stopped after the error.
		if iter.Next() {
			t.Fatalf("%x: Next returned true after error", tc.input)
		}
	}
}
// FuzzIteratorCount checks, for arbitrary input bytes, that the value returned
// by Count before iterating equals the number of times Next returns true.
func FuzzIteratorCount(f *testing.F) {
	// Seed corpus.
	for _, seed := range []string{"010203", "018142", "01830202"} {
		f.Add(unhex(seed))
	}
	f.Fuzz(func(t *testing.T, in []byte) {
		iter := newIterator(in, 0)
		// Count must be taken before Next consumes any data.
		want := iter.Count()

		steps := 0
		for iter.Next() {
			steps++
		}
		if steps != want {
			t.Fatalf("%x: count %d not equal to %d iterations", in, want, steps)
		}
	})
}

View file

@ -285,7 +285,7 @@ func CountValues(b []byte) (int, error) {
for ; len(b) > 0; i++ { for ; len(b) > 0; i++ {
_, tagsize, size, err := readKind(b) _, tagsize, size, err := readKind(b)
if err != nil { if err != nil {
return 0, err return i + 1, err
} }
b = b[tagsize+size:] b = b[tagsize+size:]
} }

View file

@ -288,9 +288,9 @@ func TestCountValues(t *testing.T) {
{"820101 820202 8403030303 04", 4, nil}, {"820101 820202 8403030303 04", 4, nil},
// size errors // size errors
{"8142", 0, ErrCanonSize}, {"8142", 1, ErrCanonSize},
{"01 01 8142", 0, ErrCanonSize}, {"01 01 8142", 3, ErrCanonSize},
{"02 84020202", 0, ErrValueTooLarge}, {"02 84020202", 2, ErrValueTooLarge},
{ {
input: "A12000BF49F440A1CD0527E4D06E2765654C0F56452257516D793A9B8D604DCFDF2AB853F851808D10000000000000000000000000A056E81F171BCC55A6FF8345E692C0F86E5B48E01B996CADC001622FB5E363B421A0C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470", input: "A12000BF49F440A1CD0527E4D06E2765654C0F56452257516D793A9B8D604DCFDF2AB853F851808D10000000000000000000000000A056E81F171BCC55A6FF8345E692C0F86E5B48E01B996CADC001622FB5E363B421A0C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470",

View file

@ -161,11 +161,14 @@ func decodeNodeUnsafe(hash, buf []byte) (node, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("decode error: %v", err) return nil, fmt.Errorf("decode error: %v", err)
} }
switch c, _ := rlp.CountValues(elems); c { c, err := rlp.CountValues(elems)
case 2: switch {
case err != nil:
return nil, fmt.Errorf("invalid node list: %v", err)
case c == 2:
n, err := decodeShort(hash, elems) n, err := decodeShort(hash, elems)
return n, wrapError(err, "short") return n, wrapError(err, "short")
case 17: case c == 17:
n, err := decodeFull(hash, elems) n, err := decodeFull(hash, elems)
return n, wrapError(err, "full") return n, wrapError(err, "full")
default: default: