mirror of
https://github.com/ethereum/go-ethereum.git
synced 2026-02-26 07:37:20 +00:00
rlp: add back Iterator.Count, with fixes (#33841)
I removed `Iterator.Count` in #33840, because it appeared to be unused and did not provide the documented invariant: the returned count should always be an upper bound on the number of iterations allowed by `Next`. In order to make `Count` work, the semantics of `CountValues` have to change so that it returns the number of items up to and including the invalid one. I have reviewed all call sites of `CountValues` to assess whether changing this is safe. There aren't that many, and the only call that doesn't check the error and return is in the trie node parser, `trie.decodeNodeUnsafe`. There, we distinguish the node type based on the number of items, and it previously returned an error for item count zero. In order to avoid any potential issue that could result from this change, I'm adding an error check in that function, though it isn't necessary.
This commit is contained in:
parent
4f38a76438
commit
ac85a6f254
5 changed files with 85 additions and 7 deletions
|
|
@ -69,6 +69,15 @@ func (it *Iterator) Value() []byte {
|
||||||
return it.next
|
return it.next
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Count returns the remaining number of items.
|
||||||
|
// Note this is O(n). If the list data is invalid, counting stops at the first invalid item, which is included in the result.
|
||||||
|
// The returned count is always an upper bound on the remaining items
|
||||||
|
// that will be visited by the iterator.
|
||||||
|
func (it *Iterator) Count() int {
|
||||||
|
count, _ := CountValues(it.data) // error deliberately dropped: on failure CountValues still reports the items seen so far, including the invalid one
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
// Offset returns the offset of the current value into the list data.
|
// Offset returns the offset of the current value into the list data.
|
||||||
func (it *Iterator) Offset() int {
|
func (it *Iterator) Offset() int {
|
||||||
return it.offset - len(it.next)
|
return it.offset - len(it.next)
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
package rlp
|
package rlp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/common/hexutil"
|
"github.com/ethereum/go-ethereum/common/hexutil"
|
||||||
|
|
@ -54,6 +55,9 @@ func TestIterator(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
if c := txit.Count(); c != 2 {
|
||||||
|
t.Fatal("wrong Count:", c)
|
||||||
|
}
|
||||||
var i = 0
|
var i = 0
|
||||||
for txit.Next() {
|
for txit.Next() {
|
||||||
if txit.err != nil {
|
if txit.err != nil {
|
||||||
|
|
@ -65,3 +69,65 @@ func TestIterator(t *testing.T) {
|
||||||
t.Errorf("count wrong, expected %d got %d", i, exp)
|
t.Errorf("count wrong, expected %d got %d", i, exp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestIteratorErrors checks lists whose second item is malformed: Count must
// report 2 before iterating, exactly one item must be visited without error,
// Err must return the expected error, and Next must return false afterwards.
func TestIteratorErrors(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input []byte
|
||||||
|
wantCount int // expected Count before iterating
|
||||||
|
wantErr error // expected it.Err() once iteration has stopped
|
||||||
|
}{
|
||||||
|
// Second item string header claims 3 bytes content, but only 2 remain.
|
||||||
|
{unhex("C4 01 83AABB"), 2, ErrValueTooLarge},
|
||||||
|
// Second item truncated: B9 requires 2 size bytes, none available.
|
||||||
|
{unhex("C2 01 B9"), 2, io.ErrUnexpectedEOF},
|
||||||
|
// 0x05 should be encoded directly, not as 81 05.
|
||||||
|
{unhex("C3 01 8105"), 2, ErrCanonSize},
|
||||||
|
// Long-form string header B8 used for 1-byte content (< 56).
|
||||||
|
{unhex("C4 01 B801AA"), 2, ErrCanonSize},
|
||||||
|
// Long-form list header F8 used for 1-byte content (< 56).
|
||||||
|
{unhex("C4 01 F80101"), 2, ErrCanonSize},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
it, err := NewListIterator(tt.input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("NewListIterator error:", err)
|
||||||
|
}
|
||||||
|
if c := it.Count(); c != tt.wantCount {
|
||||||
|
t.Fatalf("%x: Count = %d, want %d", tt.input, c, tt.wantCount)
|
||||||
|
}
|
||||||
|
n := 0 // number of items visited without an error
|
||||||
|
for it.Next() {
|
||||||
|
if it.Err() != nil {
|
||||||
|
break // stop at the first item that reports an error
|
||||||
|
}
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
if wantN := tt.wantCount - 1; n != wantN { // wantCount includes the invalid item, so one fewer item is valid
|
||||||
|
t.Fatalf("%x: got %d valid items, want %d", tt.input, n, wantN)
|
||||||
|
}
|
||||||
|
if it.Err() != tt.wantErr {
|
||||||
|
t.Fatalf("%x: got error %v, want %v", tt.input, it.Err(), tt.wantErr)
|
||||||
|
}
|
||||||
|
if it.Next() { // the iterator must stay stopped once an error was hit
|
||||||
|
t.Fatalf("%x: Next returned true after error", tt.input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FuzzIteratorCount checks, for arbitrary input bytes, that the value returned
// by Count before iterating equals the number of times Next returns true.
func FuzzIteratorCount(f *testing.F) {
|
||||||
|
examples := [][]byte{unhex("010203"), unhex("018142"), unhex("01830202")}
|
||||||
|
for _, e := range examples {
|
||||||
|
f.Add(e) // seed the corpus with the hand-written examples
|
||||||
|
}
|
||||||
|
f.Fuzz(func(t *testing.T, in []byte) {
|
||||||
|
it := newIterator(in, 0)
|
||||||
|
count := it.Count() // taken before any call to Next
|
||||||
|
i := 0 // number of times Next returned true
|
||||||
|
for it.Next() {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i != count {
|
||||||
|
t.Fatalf("%x: count %d not equal to %d iterations", in, count, i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -285,7 +285,7 @@ func CountValues(b []byte) (int, error) {
|
||||||
for ; len(b) > 0; i++ {
|
for ; len(b) > 0; i++ {
|
||||||
_, tagsize, size, err := readKind(b)
|
_, tagsize, size, err := readKind(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return i + 1, err
|
||||||
}
|
}
|
||||||
b = b[tagsize+size:]
|
b = b[tagsize+size:]
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -288,9 +288,9 @@ func TestCountValues(t *testing.T) {
|
||||||
{"820101 820202 8403030303 04", 4, nil},
|
{"820101 820202 8403030303 04", 4, nil},
|
||||||
|
|
||||||
// size errors
|
// size errors
|
||||||
{"8142", 0, ErrCanonSize},
|
{"8142", 1, ErrCanonSize},
|
||||||
{"01 01 8142", 0, ErrCanonSize},
|
{"01 01 8142", 3, ErrCanonSize},
|
||||||
{"02 84020202", 0, ErrValueTooLarge},
|
{"02 84020202", 2, ErrValueTooLarge},
|
||||||
|
|
||||||
{
|
{
|
||||||
input: "A12000BF49F440A1CD0527E4D06E2765654C0F56452257516D793A9B8D604DCFDF2AB853F851808D10000000000000000000000000A056E81F171BCC55A6FF8345E692C0F86E5B48E01B996CADC001622FB5E363B421A0C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470",
|
input: "A12000BF49F440A1CD0527E4D06E2765654C0F56452257516D793A9B8D604DCFDF2AB853F851808D10000000000000000000000000A056E81F171BCC55A6FF8345E692C0F86E5B48E01B996CADC001622FB5E363B421A0C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470",
|
||||||
|
|
|
||||||
|
|
@ -161,11 +161,14 @@ func decodeNodeUnsafe(hash, buf []byte) (node, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("decode error: %v", err)
|
return nil, fmt.Errorf("decode error: %v", err)
|
||||||
}
|
}
|
||||||
switch c, _ := rlp.CountValues(elems); c {
|
c, err := rlp.CountValues(elems)
|
||||||
case 2:
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
return nil, fmt.Errorf("invalid node list: %v", err)
|
||||||
|
case c == 2:
|
||||||
n, err := decodeShort(hash, elems)
|
n, err := decodeShort(hash, elems)
|
||||||
return n, wrapError(err, "short")
|
return n, wrapError(err, "short")
|
||||||
case 17:
|
case c == 17:
|
||||||
n, err := decodeFull(hash, elems)
|
n, err := decodeFull(hash, elems)
|
||||||
return n, wrapError(err, "full")
|
return n, wrapError(err, "full")
|
||||||
default:
|
default:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue