mirror of
https://github.com/ethereum/go-ethereum.git
synced 2026-02-26 15:47:21 +00:00
rlp/rlpgen: implement package renaming support (#31148)
This adds support for importing types from multiple identically-named packages. --------- Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
parent
038ff766ff
commit
5572f2ed22
4 changed files with 173 additions and 25 deletions
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
|
"github.com/ethereum/go-ethereum/rlp/internal/rlpstruct"
|
||||||
|
"golang.org/x/tools/go/packages"
|
||||||
)
|
)
|
||||||
|
|
||||||
// buildContext keeps the data needed for make*Op.
|
// buildContext keeps the data needed for make*Op.
|
||||||
|
|
@ -96,14 +97,20 @@ func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
|
||||||
// file and assigns unique names of temporary variables.
|
// file and assigns unique names of temporary variables.
|
||||||
type genContext struct {
|
type genContext struct {
|
||||||
inPackage *types.Package
|
inPackage *types.Package
|
||||||
imports map[string]struct{}
|
imports map[string]genImportPackage
|
||||||
tempCounter int
|
tempCounter int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type genImportPackage struct {
|
||||||
|
alias string
|
||||||
|
pkg *types.Package
|
||||||
|
}
|
||||||
|
|
||||||
func newGenContext(inPackage *types.Package) *genContext {
|
func newGenContext(inPackage *types.Package) *genContext {
|
||||||
return &genContext{
|
return &genContext{
|
||||||
inPackage: inPackage,
|
inPackage: inPackage,
|
||||||
imports: make(map[string]struct{}),
|
imports: make(map[string]genImportPackage),
|
||||||
|
tempCounter: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -117,32 +124,78 @@ func (ctx *genContext) resetTemp() {
|
||||||
ctx.tempCounter = 0
|
ctx.tempCounter = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctx *genContext) addImport(path string) {
|
func (ctx *genContext) addImportPath(path string) {
|
||||||
if path == ctx.inPackage.Path() {
|
pkg, err := ctx.loadPackage(path)
|
||||||
return // avoid importing the package that we're generating in.
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("can't load package %q: %v", path, err))
|
||||||
}
|
}
|
||||||
// TODO: renaming?
|
ctx.addImport(pkg)
|
||||||
ctx.imports[path] = struct{}{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// importsList returns all packages that need to be imported.
|
func (ctx *genContext) addImport(pkg *types.Package) string {
|
||||||
func (ctx *genContext) importsList() []string {
|
if pkg.Path() == ctx.inPackage.Path() {
|
||||||
imp := make([]string, 0, len(ctx.imports))
|
return "" // avoid importing the package that we're generating in
|
||||||
for k := range ctx.imports {
|
|
||||||
imp = append(imp, k)
|
|
||||||
}
|
}
|
||||||
sort.Strings(imp)
|
if p, exists := ctx.imports[pkg.Path()]; exists {
|
||||||
return imp
|
return p.alias
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
baseName = pkg.Name()
|
||||||
|
alias = baseName
|
||||||
|
counter = 1
|
||||||
|
)
|
||||||
|
// If the base name conflicts with an existing import, add a numeric suffix.
|
||||||
|
for ctx.hasAlias(alias) {
|
||||||
|
alias = fmt.Sprintf("%s%d", baseName, counter)
|
||||||
|
counter++
|
||||||
|
}
|
||||||
|
ctx.imports[pkg.Path()] = genImportPackage{alias, pkg}
|
||||||
|
return alias
|
||||||
}
|
}
|
||||||
|
|
||||||
// qualify is the types.Qualifier used for printing types.
|
// hasAlias checks if an alias is already in use
|
||||||
|
func (ctx *genContext) hasAlias(alias string) bool {
|
||||||
|
for _, p := range ctx.imports {
|
||||||
|
if p.alias == alias {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadPackage attempts to load package information
|
||||||
|
func (ctx *genContext) loadPackage(path string) (*types.Package, error) {
|
||||||
|
cfg := &packages.Config{Mode: packages.NeedName}
|
||||||
|
pkgs, err := packages.Load(cfg, path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(pkgs) == 0 {
|
||||||
|
return nil, fmt.Errorf("no package found for path %s", path)
|
||||||
|
}
|
||||||
|
return types.NewPackage(path, pkgs[0].Name), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// qualify is the types.Qualifier used for printing types
|
||||||
func (ctx *genContext) qualify(pkg *types.Package) string {
|
func (ctx *genContext) qualify(pkg *types.Package) string {
|
||||||
if pkg.Path() == ctx.inPackage.Path() {
|
if pkg.Path() == ctx.inPackage.Path() {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
ctx.addImport(pkg.Path())
|
return ctx.addImport(pkg)
|
||||||
// TODO: renaming?
|
}
|
||||||
return pkg.Name()
|
|
||||||
|
// importsList returns all packages that need to be imported
|
||||||
|
func (ctx *genContext) importsList() []string {
|
||||||
|
imp := make([]string, 0, len(ctx.imports))
|
||||||
|
for path, p := range ctx.imports {
|
||||||
|
if p.alias == p.pkg.Name() {
|
||||||
|
imp = append(imp, fmt.Sprintf("%q", path))
|
||||||
|
} else {
|
||||||
|
imp = append(imp, fmt.Sprintf("%s %q", p.alias, path))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Strings(imp)
|
||||||
|
return imp
|
||||||
}
|
}
|
||||||
|
|
||||||
type op interface {
|
type op interface {
|
||||||
|
|
@ -359,7 +412,7 @@ func (op uint256Op) genWrite(ctx *genContext, v string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (op uint256Op) genDecode(ctx *genContext) (string, string) {
|
func (op uint256Op) genDecode(ctx *genContext) (string, string) {
|
||||||
ctx.addImport("github.com/holiman/uint256")
|
ctx.addImportPath("github.com/holiman/uint256")
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
resultV := ctx.temp()
|
resultV := ctx.temp()
|
||||||
|
|
@ -732,7 +785,7 @@ func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstru
|
||||||
// generateDecoder generates the DecodeRLP method on 'typ'.
|
// generateDecoder generates the DecodeRLP method on 'typ'.
|
||||||
func generateDecoder(ctx *genContext, typ string, op op) []byte {
|
func generateDecoder(ctx *genContext, typ string, op op) []byte {
|
||||||
ctx.resetTemp()
|
ctx.resetTemp()
|
||||||
ctx.addImport(pathOfPackageRLP)
|
ctx.addImportPath(pathOfPackageRLP)
|
||||||
|
|
||||||
result, code := op.genDecode(ctx)
|
result, code := op.genDecode(ctx)
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
|
|
@ -747,8 +800,8 @@ func generateDecoder(ctx *genContext, typ string, op op) []byte {
|
||||||
// generateEncoder generates the EncodeRLP method on 'typ'.
|
// generateEncoder generates the EncodeRLP method on 'typ'.
|
||||||
func generateEncoder(ctx *genContext, typ string, op op) []byte {
|
func generateEncoder(ctx *genContext, typ string, op op) []byte {
|
||||||
ctx.resetTemp()
|
ctx.resetTemp()
|
||||||
ctx.addImport("io")
|
ctx.addImportPath("io")
|
||||||
ctx.addImport(pathOfPackageRLP)
|
ctx.addImportPath(pathOfPackageRLP)
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ)
|
fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ)
|
||||||
|
|
@ -783,7 +836,7 @@ func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]b
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
fmt.Fprintf(&b, "package %s\n\n", pkg.Name())
|
fmt.Fprintf(&b, "package %s\n\n", pkg.Name())
|
||||||
for _, imp := range ctx.importsList() {
|
for _, imp := range ctx.importsList() {
|
||||||
fmt.Fprintf(&b, "import %q\n", imp)
|
fmt.Fprintf(&b, "import %s\n", imp)
|
||||||
}
|
}
|
||||||
if encoder {
|
if encoder {
|
||||||
fmt.Fprintln(&b)
|
fmt.Fprintln(&b)
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ func init() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256"}
|
var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256", "pkgclash"}
|
||||||
|
|
||||||
func TestOutput(t *testing.T) {
|
func TestOutput(t *testing.T) {
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
|
|
||||||
13
rlp/rlpgen/testdata/pkgclash.in.txt
vendored
Normal file
13
rlp/rlpgen/testdata/pkgclash.in.txt
vendored
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
// -*- mode: go -*-
|
||||||
|
|
||||||
|
package test
|
||||||
|
|
||||||
|
import (
|
||||||
|
eth1 "github.com/ethereum/go-ethereum/eth"
|
||||||
|
eth2 "github.com/ethereum/go-ethereum/eth/protocols/eth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Test struct {
|
||||||
|
A eth1.MinerAPI
|
||||||
|
B eth2.GetReceiptsPacket
|
||||||
|
}
|
||||||
82
rlp/rlpgen/testdata/pkgclash.out.txt
vendored
Normal file
82
rlp/rlpgen/testdata/pkgclash.out.txt
vendored
Normal file
|
|
@ -0,0 +1,82 @@
|
||||||
|
package test
|
||||||
|
|
||||||
|
import "github.com/ethereum/go-ethereum/common"
|
||||||
|
import "github.com/ethereum/go-ethereum/eth"
|
||||||
|
import "github.com/ethereum/go-ethereum/rlp"
|
||||||
|
import "io"
|
||||||
|
import eth1 "github.com/ethereum/go-ethereum/eth/protocols/eth"
|
||||||
|
|
||||||
|
func (obj *Test) EncodeRLP(_w io.Writer) error {
|
||||||
|
w := rlp.NewEncoderBuffer(_w)
|
||||||
|
_tmp0 := w.List()
|
||||||
|
_tmp1 := w.List()
|
||||||
|
w.ListEnd(_tmp1)
|
||||||
|
_tmp2 := w.List()
|
||||||
|
w.WriteUint64(obj.B.RequestId)
|
||||||
|
_tmp3 := w.List()
|
||||||
|
for _, _tmp4 := range obj.B.GetReceiptsRequest {
|
||||||
|
w.WriteBytes(_tmp4[:])
|
||||||
|
}
|
||||||
|
w.ListEnd(_tmp3)
|
||||||
|
w.ListEnd(_tmp2)
|
||||||
|
w.ListEnd(_tmp0)
|
||||||
|
return w.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (obj *Test) DecodeRLP(dec *rlp.Stream) error {
|
||||||
|
var _tmp0 Test
|
||||||
|
{
|
||||||
|
if _, err := dec.List(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// A:
|
||||||
|
var _tmp1 eth.MinerAPI
|
||||||
|
{
|
||||||
|
if _, err := dec.List(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := dec.ListEnd(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_tmp0.A = _tmp1
|
||||||
|
// B:
|
||||||
|
var _tmp2 eth1.GetReceiptsPacket
|
||||||
|
{
|
||||||
|
if _, err := dec.List(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// RequestId:
|
||||||
|
_tmp3, err := dec.Uint64()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_tmp2.RequestId = _tmp3
|
||||||
|
// GetReceiptsRequest:
|
||||||
|
var _tmp4 []common.Hash
|
||||||
|
if _, err := dec.List(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for dec.MoreDataInList() {
|
||||||
|
var _tmp5 common.Hash
|
||||||
|
if err := dec.ReadBytes(_tmp5[:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_tmp4 = append(_tmp4, _tmp5)
|
||||||
|
}
|
||||||
|
if err := dec.ListEnd(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_tmp2.GetReceiptsRequest = _tmp4
|
||||||
|
if err := dec.ListEnd(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_tmp0.B = _tmp2
|
||||||
|
if err := dec.ListEnd(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*obj = _tmp0
|
||||||
|
return nil
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue