diff --git a/common/bitutil/bitutil.go b/common/bitutil/bitutil.go index db5b7df1b2..1ffc991a9b 100644 --- a/common/bitutil/bitutil.go +++ b/common/bitutil/bitutil.go @@ -76,7 +76,14 @@ func safeANDBytes(dst, a, b []byte) int { // ORBytes ors the bytes in a and b. The destination is assumed to have enough // space. Returns the number of bytes or'd. +// +// dst and x or y may overlap exactly or not at all, +// otherwise ORBytes may panic. func ORBytes(dst, a, b []byte) int { + n := min(len(a), len(b)) + if inexactOverlap(dst[:n], a[:n]) || inexactOverlap(dst[:n], b[:n]) { + panic("ORBytes: invalid overlap") + } return orBytes(dst, a, b) } @@ -132,3 +139,26 @@ func safeTestBytes(p []byte) bool { } return false } + +// anyOverlap reports whether x and y share memory at any (not necessarily +// corresponding) index. The memory beyond the slice length is ignored. +// from: https://github.com/golang/go/blob/4a3cef2036097d323b6cc0bbe90fc4d8c7588660/src/crypto/internal/fips140/alias/alias.go#L13-L17 +func anyOverlap(x, y []byte) bool { + return len(x) > 0 && len(y) > 0 && + uintptr(unsafe.Pointer(&x[0])) <= uintptr(unsafe.Pointer(&y[len(y)-1])) && + uintptr(unsafe.Pointer(&y[0])) <= uintptr(unsafe.Pointer(&x[len(x)-1])) +} + +// inexactOverlap reports whether x and y share memory at any non-corresponding +// index. The memory beyond the slice length is ignored. Note that x and y can +// have different lengths and still not have any inexact overlap. +// +// inexactOverlap can be used to implement the requirements of the crypto/cipher +// AEAD, Block, BlockMode and Stream interfaces. +// from: https://github.com/golang/go/blob/4a3cef2036097d323b6cc0bbe90fc4d8c7588660/src/crypto/internal/fips140/alias/alias.go#L25-L30 +func inexactOverlap(x, y []byte) bool { + if len(x) == 0 || len(y) == 0 || &x[0] == &y[0] { + return false + } + return anyOverlap(x, y) +} diff --git a/common/bitutil/or_asm.go b/common/bitutil/or_asm.go index 4337894299..6c6331e0e4 100644 --- a/common/bitutil/or_asm.go +++ b/common/bitutil/or_asm.go @@ -8,12 +8,8 @@ package bitutil func orBytes(dst, a, b []byte) int { - n := min(len(a), len(b)) - if n == 0 { - return 0 - } - orBytesASM(&dst[0], &a[0], &b[0], n) - return n + orBytesASM(&dst[0], &a[0], &b[0], len(a)) + return len(a) } //go:noescape