forked from forks/go-ethereum
p2p: fix DiscReason encoding/decoding (#30855)
This fixes an issue where the disconnect message was not wrapped in a list. The specification requires it to be a list like any other message. In order to remain compatible with legacy geth versions, we now accept both encodings when parsing a disconnect message. --------- Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
parent
c7e740f40c
commit
c1c2507148
4 changed files with 44 additions and 20 deletions
25
p2p/peer.go
25
p2p/peer.go
|
|
@ -345,9 +345,7 @@ func (p *Peer) handle(msg Msg) error {
|
||||||
case msg.Code == discMsg:
|
case msg.Code == discMsg:
|
||||||
// This is the last message. We don't need to discard or
|
// This is the last message. We don't need to discard or
|
||||||
// check errors because, the connection will be closed after it.
|
// check errors because, the connection will be closed after it.
|
||||||
var m struct{ R DiscReason }
|
return decodeDisconnectMessage(msg.Payload)
|
||||||
rlp.Decode(msg.Payload, &m)
|
|
||||||
return m.R
|
|
||||||
case msg.Code < baseProtocolLength:
|
case msg.Code < baseProtocolLength:
|
||||||
// ignore other base protocol messages
|
// ignore other base protocol messages
|
||||||
return msg.Discard()
|
return msg.Discard()
|
||||||
|
|
@ -372,6 +370,27 @@ func (p *Peer) handle(msg Msg) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// decodeDisconnectMessage decodes the payload of discMsg.
|
||||||
|
func decodeDisconnectMessage(r io.Reader) (reason DiscReason) {
|
||||||
|
s := rlp.NewStream(r, 100)
|
||||||
|
k, _, err := s.Kind()
|
||||||
|
if err != nil {
|
||||||
|
return DiscInvalid
|
||||||
|
}
|
||||||
|
if k == rlp.List {
|
||||||
|
s.List()
|
||||||
|
err = s.Decode(&reason)
|
||||||
|
} else {
|
||||||
|
// Legacy path: some implementations, including geth, used to send the disconnect
|
||||||
|
// reason as a byte array by accident.
|
||||||
|
err = s.Decode(&reason)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
reason = DiscInvalid
|
||||||
|
}
|
||||||
|
return reason
|
||||||
|
}
|
||||||
|
|
||||||
func countMatchingProtocols(protocols []Protocol, caps []Cap) int {
|
func countMatchingProtocols(protocols []Protocol, caps []Cap) int {
|
||||||
n := 0
|
n := 0
|
||||||
for _, cap := range caps {
|
for _, cap := range caps {
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,8 @@ const (
|
||||||
DiscSelf
|
DiscSelf
|
||||||
DiscReadTimeout
|
DiscReadTimeout
|
||||||
DiscSubprotocolError = DiscReason(0x10)
|
DiscSubprotocolError = DiscReason(0x10)
|
||||||
|
|
||||||
|
DiscInvalid = 0xff
|
||||||
)
|
)
|
||||||
|
|
||||||
var discReasonToString = [...]string{
|
var discReasonToString = [...]string{
|
||||||
|
|
@ -86,10 +88,11 @@ var discReasonToString = [...]string{
|
||||||
DiscSelf: "connected to self",
|
DiscSelf: "connected to self",
|
||||||
DiscReadTimeout: "read timeout",
|
DiscReadTimeout: "read timeout",
|
||||||
DiscSubprotocolError: "subprotocol error",
|
DiscSubprotocolError: "subprotocol error",
|
||||||
|
DiscInvalid: "invalid disconnect reason",
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d DiscReason) String() string {
|
func (d DiscReason) String() string {
|
||||||
if len(discReasonToString) <= int(d) {
|
if len(discReasonToString) <= int(d) || discReasonToString[d] == "" {
|
||||||
return fmt.Sprintf("unknown disconnect reason %d", d)
|
return fmt.Sprintf("unknown disconnect reason %d", d)
|
||||||
}
|
}
|
||||||
return discReasonToString[d]
|
return discReasonToString[d]
|
||||||
|
|
|
||||||
|
|
@ -113,15 +113,14 @@ func (t *rlpxTransport) close(err error) {
|
||||||
// Tell the remote end why we're disconnecting if possible.
|
// Tell the remote end why we're disconnecting if possible.
|
||||||
// We only bother doing this if the underlying connection supports
|
// We only bother doing this if the underlying connection supports
|
||||||
// setting a timeout tough.
|
// setting a timeout tough.
|
||||||
if t.conn != nil {
|
if reason, ok := err.(DiscReason); ok && reason != DiscNetworkError {
|
||||||
if r, ok := err.(DiscReason); ok && r != DiscNetworkError {
|
// We do not use the WriteMsg func since we want a custom deadline
|
||||||
deadline := time.Now().Add(discWriteTimeout)
|
deadline := time.Now().Add(discWriteTimeout)
|
||||||
if err := t.conn.SetWriteDeadline(deadline); err == nil {
|
if err := t.conn.SetWriteDeadline(deadline); err == nil {
|
||||||
// Connection supports write deadline.
|
// Connection supports write deadline.
|
||||||
t.wbuf.Reset()
|
t.wbuf.Reset()
|
||||||
rlp.Encode(&t.wbuf, []DiscReason{r})
|
rlp.Encode(&t.wbuf, []any{reason})
|
||||||
t.conn.Write(discMsg, t.wbuf.Bytes())
|
t.conn.Write(discMsg, t.wbuf.Bytes())
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
t.conn.Close()
|
t.conn.Close()
|
||||||
|
|
@ -163,11 +162,8 @@ func readProtocolHandshake(rw MsgReader) (*protoHandshake, error) {
|
||||||
if msg.Code == discMsg {
|
if msg.Code == discMsg {
|
||||||
// Disconnect before protocol handshake is valid according to the
|
// Disconnect before protocol handshake is valid according to the
|
||||||
// spec and we send it ourself if the post-handshake checks fail.
|
// spec and we send it ourself if the post-handshake checks fail.
|
||||||
// We can't return the reason directly, though, because it is echoed
|
r := decodeDisconnectMessage(msg.Payload)
|
||||||
// back otherwise. Wrap it in a string instead.
|
return nil, r
|
||||||
var reason [1]DiscReason
|
|
||||||
rlp.Decode(msg.Payload, &reason)
|
|
||||||
return nil, reason[0]
|
|
||||||
}
|
}
|
||||||
if msg.Code != handshakeMsg {
|
if msg.Code != handshakeMsg {
|
||||||
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
|
return nil, fmt.Errorf("expected handshake, got %x", msg.Code)
|
||||||
|
|
|
||||||
|
|
@ -97,7 +97,7 @@ func TestProtocolHandshake(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ExpectMsg(rlpx, discMsg, []DiscReason{DiscQuitting}); err != nil {
|
if err := ExpectMsg(rlpx, discMsg, []any{DiscQuitting}); err != nil {
|
||||||
t.Errorf("error receiving disconnect: %v", err)
|
t.Errorf("error receiving disconnect: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
@ -112,7 +112,13 @@ func TestProtocolHandshakeErrors(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
code: discMsg,
|
code: discMsg,
|
||||||
msg: []DiscReason{DiscQuitting},
|
msg: []any{DiscQuitting},
|
||||||
|
err: DiscQuitting,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// legacy disconnect encoding as byte array
|
||||||
|
code: discMsg,
|
||||||
|
msg: []byte{byte(DiscQuitting)},
|
||||||
err: DiscQuitting,
|
err: DiscQuitting,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue