mirror of
https://github.com/ethereum/go-ethereum.git
synced 2026-02-26 07:37:20 +00:00
rlp/rlpgen: implement package renaming support (#31148)
This adds support for importing types from multiple identically-named packages. --------- Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
parent
038ff766ff
commit
5572f2ed22
4 changed files with 173 additions and 25 deletions
|
|
@ -24,6 +24,7 @@ import (
|
|||
"sort"
|
||||
|
||||
"github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
|
||||
"golang.org/x/tools/go/packages"
|
||||
)
|
||||
|
||||
// buildContext keeps the data needed for make*Op.
|
||||
|
|
@ -96,14 +97,20 @@ func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
|
|||
// file and assigns unique names of temporary variables.
|
||||
type genContext struct {
|
||||
inPackage *types.Package
|
||||
imports map[string]struct{}
|
||||
imports map[string]genImportPackage
|
||||
tempCounter int
|
||||
}
|
||||
|
||||
type genImportPackage struct {
|
||||
alias string
|
||||
pkg *types.Package
|
||||
}
|
||||
|
||||
func newGenContext(inPackage *types.Package) *genContext {
|
||||
return &genContext{
|
||||
inPackage: inPackage,
|
||||
imports: make(map[string]struct{}),
|
||||
inPackage: inPackage,
|
||||
imports: make(map[string]genImportPackage),
|
||||
tempCounter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -117,32 +124,78 @@ func (ctx *genContext) resetTemp() {
|
|||
ctx.tempCounter = 0
|
||||
}
|
||||
|
||||
func (ctx *genContext) addImport(path string) {
|
||||
if path == ctx.inPackage.Path() {
|
||||
return // avoid importing the package that we're generating in.
|
||||
func (ctx *genContext) addImportPath(path string) {
|
||||
pkg, err := ctx.loadPackage(path)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("can't load package %q: %v", path, err))
|
||||
}
|
||||
// TODO: renaming?
|
||||
ctx.imports[path] = struct{}{}
|
||||
ctx.addImport(pkg)
|
||||
}
|
||||
|
||||
// importsList returns all packages that need to be imported.
|
||||
func (ctx *genContext) importsList() []string {
|
||||
imp := make([]string, 0, len(ctx.imports))
|
||||
for k := range ctx.imports {
|
||||
imp = append(imp, k)
|
||||
func (ctx *genContext) addImport(pkg *types.Package) string {
|
||||
if pkg.Path() == ctx.inPackage.Path() {
|
||||
return "" // avoid importing the package that we're generating in
|
||||
}
|
||||
sort.Strings(imp)
|
||||
return imp
|
||||
if p, exists := ctx.imports[pkg.Path()]; exists {
|
||||
return p.alias
|
||||
}
|
||||
var (
|
||||
baseName = pkg.Name()
|
||||
alias = baseName
|
||||
counter = 1
|
||||
)
|
||||
// If the base name conflicts with an existing import, add a numeric suffix.
|
||||
for ctx.hasAlias(alias) {
|
||||
alias = fmt.Sprintf("%s%d", baseName, counter)
|
||||
counter++
|
||||
}
|
||||
ctx.imports[pkg.Path()] = genImportPackage{alias, pkg}
|
||||
return alias
|
||||
}
|
||||
|
||||
// qualify is the types.Qualifier used for printing types.
|
||||
// hasAlias checks if an alias is already in use
|
||||
func (ctx *genContext) hasAlias(alias string) bool {
|
||||
for _, p := range ctx.imports {
|
||||
if p.alias == alias {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// loadPackage attempts to load package information
|
||||
func (ctx *genContext) loadPackage(path string) (*types.Package, error) {
|
||||
cfg := &packages.Config{Mode: packages.NeedName}
|
||||
pkgs, err := packages.Load(cfg, path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(pkgs) == 0 {
|
||||
return nil, fmt.Errorf("no package found for path %s", path)
|
||||
}
|
||||
return types.NewPackage(path, pkgs[0].Name), nil
|
||||
}
|
||||
|
||||
// qualify is the types.Qualifier used for printing types
|
||||
func (ctx *genContext) qualify(pkg *types.Package) string {
|
||||
if pkg.Path() == ctx.inPackage.Path() {
|
||||
return ""
|
||||
}
|
||||
ctx.addImport(pkg.Path())
|
||||
// TODO: renaming?
|
||||
return pkg.Name()
|
||||
return ctx.addImport(pkg)
|
||||
}
|
||||
|
||||
// importsList returns all packages that need to be imported
|
||||
func (ctx *genContext) importsList() []string {
|
||||
imp := make([]string, 0, len(ctx.imports))
|
||||
for path, p := range ctx.imports {
|
||||
if p.alias == p.pkg.Name() {
|
||||
imp = append(imp, fmt.Sprintf("%q", path))
|
||||
} else {
|
||||
imp = append(imp, fmt.Sprintf("%s %q", p.alias, path))
|
||||
}
|
||||
}
|
||||
sort.Strings(imp)
|
||||
return imp
|
||||
}
|
||||
|
||||
type op interface {
|
||||
|
|
@ -359,7 +412,7 @@ func (op uint256Op) genWrite(ctx *genContext, v string) string {
|
|||
}
|
||||
|
||||
func (op uint256Op) genDecode(ctx *genContext) (string, string) {
|
||||
ctx.addImport("github.com/holiman/uint256")
|
||||
ctx.addImportPath("github.com/holiman/uint256")
|
||||
|
||||
var b bytes.Buffer
|
||||
resultV := ctx.temp()
|
||||
|
|
@ -732,7 +785,7 @@ func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstru
|
|||
// generateDecoder generates the DecodeRLP method on 'typ'.
|
||||
func generateDecoder(ctx *genContext, typ string, op op) []byte {
|
||||
ctx.resetTemp()
|
||||
ctx.addImport(pathOfPackageRLP)
|
||||
ctx.addImportPath(pathOfPackageRLP)
|
||||
|
||||
result, code := op.genDecode(ctx)
|
||||
var b bytes.Buffer
|
||||
|
|
@ -747,8 +800,8 @@ func generateDecoder(ctx *genContext, typ string, op op) []byte {
|
|||
// generateEncoder generates the EncodeRLP method on 'typ'.
|
||||
func generateEncoder(ctx *genContext, typ string, op op) []byte {
|
||||
ctx.resetTemp()
|
||||
ctx.addImport("io")
|
||||
ctx.addImport(pathOfPackageRLP)
|
||||
ctx.addImportPath("io")
|
||||
ctx.addImportPath(pathOfPackageRLP)
|
||||
|
||||
var b bytes.Buffer
|
||||
fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ)
|
||||
|
|
@ -783,7 +836,7 @@ func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]b
|
|||
var b bytes.Buffer
|
||||
fmt.Fprintf(&b, "package %s\n\n", pkg.Name())
|
||||
for _, imp := range ctx.importsList() {
|
||||
fmt.Fprintf(&b, "import %q\n", imp)
|
||||
fmt.Fprintf(&b, "import %s\n", imp)
|
||||
}
|
||||
if encoder {
|
||||
fmt.Fprintln(&b)
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ func init() {
|
|||
}
|
||||
}
|
||||
|
||||
var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256"}
|
||||
var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256", "pkgclash"}
|
||||
|
||||
func TestOutput(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
|
|
|
|||
13
rlp/rlpgen/testdata/pkgclash.in.txt
vendored
Normal file
13
rlp/rlpgen/testdata/pkgclash.in.txt
vendored
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
// -*- mode: go -*-
|
||||
|
||||
package test
|
||||
|
||||
import (
|
||||
eth1 "github.com/ethereum/go-ethereum/eth"
|
||||
eth2 "github.com/ethereum/go-ethereum/eth/protocols/eth"
|
||||
)
|
||||
|
||||
type Test struct {
|
||||
A eth1.MinerAPI
|
||||
B eth2.GetReceiptsPacket
|
||||
}
|
||||
82
rlp/rlpgen/testdata/pkgclash.out.txt
vendored
Normal file
82
rlp/rlpgen/testdata/pkgclash.out.txt
vendored
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package test
|
||||
|
||||
import "github.com/ethereum/go-ethereum/common"
|
||||
import "github.com/ethereum/go-ethereum/eth"
|
||||
import "github.com/ethereum/go-ethereum/rlp"
|
||||
import "io"
|
||||
import eth1 "github.com/ethereum/go-ethereum/eth/protocols/eth"
|
||||
|
||||
func (obj *Test) EncodeRLP(_w io.Writer) error {
|
||||
w := rlp.NewEncoderBuffer(_w)
|
||||
_tmp0 := w.List()
|
||||
_tmp1 := w.List()
|
||||
w.ListEnd(_tmp1)
|
||||
_tmp2 := w.List()
|
||||
w.WriteUint64(obj.B.RequestId)
|
||||
_tmp3 := w.List()
|
||||
for _, _tmp4 := range obj.B.GetReceiptsRequest {
|
||||
w.WriteBytes(_tmp4[:])
|
||||
}
|
||||
w.ListEnd(_tmp3)
|
||||
w.ListEnd(_tmp2)
|
||||
w.ListEnd(_tmp0)
|
||||
return w.Flush()
|
||||
}
|
||||
|
||||
func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
|
||||
var _tmp0 Test
|
||||
{
|
||||
if _, err := dec.List(); err != nil {
|
||||
return err
|
||||
}
|
||||
// A:
|
||||
var _tmp1 eth.MinerAPI
|
||||
{
|
||||
if _, err := dec.List(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := dec.ListEnd(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_tmp0.A = _tmp1
|
||||
// B:
|
||||
var _tmp2 eth1.GetReceiptsPacket
|
||||
{
|
||||
if _, err := dec.List(); err != nil {
|
||||
return err
|
||||
}
|
||||
// RequestId:
|
||||
_tmp3, err := dec.Uint64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_tmp2.RequestId = _tmp3
|
||||
// GetReceiptsRequest:
|
||||
var _tmp4 []common.Hash
|
||||
if _, err := dec.List(); err != nil {
|
||||
return err
|
||||
}
|
||||
for dec.MoreDataInList() {
|
||||
var _tmp5 common.Hash
|
||||
if err := dec.ReadBytes(_tmp5[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
_tmp4 = append(_tmp4, _tmp5)
|
||||
}
|
||||
if err := dec.ListEnd(); err != nil {
|
||||
return err
|
||||
}
|
||||
_tmp2.GetReceiptsRequest = _tmp4
|
||||
if err := dec.ListEnd(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_tmp0.B = _tmp2
|
||||
if err := dec.ListEnd(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
*obj = _tmp0
|
||||
return nil
|
||||
}
|
||||
Loading…
Reference in a new issue