rlp/rlpgen: implement package renaming support (#31148)
Some checks failed
/ Linux Build (push) Has been cancelled
/ Linux Build (arm) (push) Has been cancelled
/ Windows Build (push) Has been cancelled
/ Docker Image (push) Has been cancelled

This adds support for importing types from multiple identically-named
packages.

---------

Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
VolodymyrBg 2025-08-01 19:30:48 +03:00 committed by GitHub
parent 038ff766ff
commit 5572f2ed22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 173 additions and 25 deletions

View file

@ -24,6 +24,7 @@ import (
"sort" "sort"
"github.com/ethereum/go-ethereum/rlp/internal/rlpstruct" "github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
"golang.org/x/tools/go/packages"
) )
// buildContext keeps the data needed for make*Op. // buildContext keeps the data needed for make*Op.
@ -96,14 +97,20 @@ func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
// file and assigns unique names of temporary variables. // file and assigns unique names of temporary variables.
type genContext struct { type genContext struct {
inPackage *types.Package inPackage *types.Package
imports map[string]struct{} imports map[string]genImportPackage
tempCounter int tempCounter int
} }
type genImportPackage struct {
alias string
pkg *types.Package
}
func newGenContext(inPackage *types.Package) *genContext { func newGenContext(inPackage *types.Package) *genContext {
return &genContext{ return &genContext{
inPackage: inPackage, inPackage: inPackage,
imports: make(map[string]struct{}), imports: make(map[string]genImportPackage),
tempCounter: 0,
} }
} }
@ -117,32 +124,78 @@ func (ctx *genContext) resetTemp() {
ctx.tempCounter = 0 ctx.tempCounter = 0
} }
func (ctx *genContext) addImport(path string) { func (ctx *genContext) addImportPath(path string) {
if path == ctx.inPackage.Path() { pkg, err := ctx.loadPackage(path)
return // avoid importing the package that we're generating in. if err != nil {
panic(fmt.Sprintf("can't load package %q: %v", path, err))
} }
// TODO: renaming? ctx.addImport(pkg)
ctx.imports[path] = struct{}{}
} }
// importsList returns all packages that need to be imported. func (ctx *genContext) addImport(pkg *types.Package) string {
func (ctx *genContext) importsList() []string { if pkg.Path() == ctx.inPackage.Path() {
imp := make([]string, 0, len(ctx.imports)) return "" // avoid importing the package that we're generating in
for k := range ctx.imports {
imp = append(imp, k)
} }
sort.Strings(imp) if p, exists := ctx.imports[pkg.Path()]; exists {
return imp return p.alias
}
var (
baseName = pkg.Name()
alias = baseName
counter = 1
)
// If the base name conflicts with an existing import, add a numeric suffix.
for ctx.hasAlias(alias) {
alias = fmt.Sprintf("%s%d", baseName, counter)
counter++
}
ctx.imports[pkg.Path()] = genImportPackage{alias, pkg}
return alias
} }
// qualify is the types.Qualifier used for printing types. // hasAlias checks if an alias is already in use
func (ctx *genContext) hasAlias(alias string) bool {
for _, p := range ctx.imports {
if p.alias == alias {
return true
}
}
return false
}
// loadPackage attempts to load package information
func (ctx *genContext) loadPackage(path string) (*types.Package, error) {
cfg := &packages.Config{Mode: packages.NeedName}
pkgs, err := packages.Load(cfg, path)
if err != nil {
return nil, err
}
if len(pkgs) == 0 {
return nil, fmt.Errorf("no package found for path %s", path)
}
return types.NewPackage(path, pkgs[0].Name), nil
}
// qualify is the types.Qualifier used for printing types
func (ctx *genContext) qualify(pkg *types.Package) string { func (ctx *genContext) qualify(pkg *types.Package) string {
if pkg.Path() == ctx.inPackage.Path() { if pkg.Path() == ctx.inPackage.Path() {
return "" return ""
} }
ctx.addImport(pkg.Path()) return ctx.addImport(pkg)
// TODO: renaming? }
return pkg.Name()
// importsList returns all packages that need to be imported
func (ctx *genContext) importsList() []string {
imp := make([]string, 0, len(ctx.imports))
for path, p := range ctx.imports {
if p.alias == p.pkg.Name() {
imp = append(imp, fmt.Sprintf("%q", path))
} else {
imp = append(imp, fmt.Sprintf("%s %q", p.alias, path))
}
}
sort.Strings(imp)
return imp
} }
type op interface { type op interface {
@ -359,7 +412,7 @@ func (op uint256Op) genWrite(ctx *genContext, v string) string {
} }
func (op uint256Op) genDecode(ctx *genContext) (string, string) { func (op uint256Op) genDecode(ctx *genContext) (string, string) {
ctx.addImport("github.com/holiman/uint256") ctx.addImportPath("github.com/holiman/uint256")
var b bytes.Buffer var b bytes.Buffer
resultV := ctx.temp() resultV := ctx.temp()
@ -732,7 +785,7 @@ func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstru
// generateDecoder generates the DecodeRLP method on 'typ'. // generateDecoder generates the DecodeRLP method on 'typ'.
func generateDecoder(ctx *genContext, typ string, op op) []byte { func generateDecoder(ctx *genContext, typ string, op op) []byte {
ctx.resetTemp() ctx.resetTemp()
ctx.addImport(pathOfPackageRLP) ctx.addImportPath(pathOfPackageRLP)
result, code := op.genDecode(ctx) result, code := op.genDecode(ctx)
var b bytes.Buffer var b bytes.Buffer
@ -747,8 +800,8 @@ func generateDecoder(ctx *genContext, typ string, op op) []byte {
// generateEncoder generates the EncodeRLP method on 'typ'. // generateEncoder generates the EncodeRLP method on 'typ'.
func generateEncoder(ctx *genContext, typ string, op op) []byte { func generateEncoder(ctx *genContext, typ string, op op) []byte {
ctx.resetTemp() ctx.resetTemp()
ctx.addImport("io") ctx.addImportPath("io")
ctx.addImport(pathOfPackageRLP) ctx.addImportPath(pathOfPackageRLP)
var b bytes.Buffer var b bytes.Buffer
fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ) fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ)
@ -783,7 +836,7 @@ func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]b
var b bytes.Buffer var b bytes.Buffer
fmt.Fprintf(&b, "package %s\n\n", pkg.Name()) fmt.Fprintf(&b, "package %s\n\n", pkg.Name())
for _, imp := range ctx.importsList() { for _, imp := range ctx.importsList() {
fmt.Fprintf(&b, "import %q\n", imp) fmt.Fprintf(&b, "import %s\n", imp)
} }
if encoder { if encoder {
fmt.Fprintln(&b) fmt.Fprintln(&b)

View file

@ -47,7 +47,7 @@ func init() {
} }
} }
var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256"} var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256", "pkgclash"}
func TestOutput(t *testing.T) { func TestOutput(t *testing.T) {
for _, test := range tests { for _, test := range tests {

13
rlp/rlpgen/testdata/pkgclash.in.txt vendored Normal file
View file

@ -0,0 +1,13 @@
// -*- mode: go -*-
package test
import (
eth1 "github.com/ethereum/go-ethereum/eth"
eth2 "github.com/ethereum/go-ethereum/eth/protocols/eth"
)
type Test struct {
A eth1.MinerAPI
B eth2.GetReceiptsPacket
}

82
rlp/rlpgen/testdata/pkgclash.out.txt vendored Normal file
View file

@ -0,0 +1,82 @@
package test
import "github.com/ethereum/go-ethereum/common"
import "github.com/ethereum/go-ethereum/eth"
import "github.com/ethereum/go-ethereum/rlp"
import "io"
import eth1 "github.com/ethereum/go-ethereum/eth/protocols/eth"
func (obj *Test) EncodeRLP(_w io.Writer) error {
w := rlp.NewEncoderBuffer(_w)
_tmp0 := w.List()
_tmp1 := w.List()
w.ListEnd(_tmp1)
_tmp2 := w.List()
w.WriteUint64(obj.B.RequestId)
_tmp3 := w.List()
for _, _tmp4 := range obj.B.GetReceiptsRequest {
w.WriteBytes(_tmp4[:])
}
w.ListEnd(_tmp3)
w.ListEnd(_tmp2)
w.ListEnd(_tmp0)
return w.Flush()
}
func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
var _tmp0 Test
{
if _, err := dec.List(); err != nil {
return err
}
// A:
var _tmp1 eth.MinerAPI
{
if _, err := dec.List(); err != nil {
return err
}
if err := dec.ListEnd(); err != nil {
return err
}
}
_tmp0.A = _tmp1
// B:
var _tmp2 eth1.GetReceiptsPacket
{
if _, err := dec.List(); err != nil {
return err
}
// RequestId:
_tmp3, err := dec.Uint64()
if err != nil {
return err
}
_tmp2.RequestId = _tmp3
// GetReceiptsRequest:
var _tmp4 []common.Hash
if _, err := dec.List(); err != nil {
return err
}
for dec.MoreDataInList() {
var _tmp5 common.Hash
if err := dec.ReadBytes(_tmp5[:]); err != nil {
return err
}
_tmp4 = append(_tmp4, _tmp5)
}
if err := dec.ListEnd(); err != nil {
return err
}
_tmp2.GetReceiptsRequest = _tmp4
if err := dec.ListEnd(); err != nil {
return err
}
}
_tmp0.B = _tmp2
if err := dec.ListEnd(); err != nil {
return err
}
}
*obj = _tmp0
return nil
}