diff --git a/rlp/decoder.go b/rlp/decoder.go new file mode 100644 index 000000000..4f41ecd7f --- /dev/null +++ b/rlp/decoder.go @@ -0,0 +1,277 @@ +package rlp + +import ( + "errors" + "fmt" + "io" +) + +type Decoder struct { + buf *buf +} + +func NewDecoder(buf []byte) *Decoder { + return &Decoder{ + buf: newBuf(buf, 0), + } +} + +func (d *Decoder) String() string { + return fmt.Sprintf(`left=%x pos=%d`, d.buf.Bytes(), d.buf.off) +} + +func (d *Decoder) Consumed() []byte { + return d.buf.u[:d.buf.off] +} + +func (d *Decoder) Underlying() []byte { + return d.buf.Underlying() +} + +func (d *Decoder) Empty() bool { + return d.buf.empty() +} + +func (d *Decoder) Offset() int { + return d.buf.Offset() +} + +func (d *Decoder) Bytes() []byte { + return d.buf.Bytes() +} + +func (d *Decoder) ReadByte() (n byte, err error) { + return d.buf.ReadByte() +} + +func (d *Decoder) PeekByte() (n byte, err error) { + return d.buf.PeekByte() +} + +func (d *Decoder) Rebase() { + d.buf.u = d.Bytes() + d.buf.off = 0 +} +func (d *Decoder) Fork() *Decoder { + return &Decoder{ + buf: newBuf(d.buf.u, d.buf.off), + } +} + +func (d *Decoder) PeekToken() (Token, error) { + prefix, err := d.PeekByte() + if err != nil { + return TokenUnknown, err + } + return identifyToken(prefix), nil +} + +func (d *Decoder) ElemDec() (*Decoder, Token, error) { + a, t, err := d.Elem() + return NewDecoder(a), t, err +} + +func (d *Decoder) RawElemDec() (*Decoder, Token, error) { + a, t, err := d.RawElem() + return NewDecoder(a), t, err +} + +func (d *Decoder) RawElem() ([]byte, Token, error) { + w := d.buf + start := w.Offset() + // figure out what we are reading + prefix, err := w.ReadByte() + if err != nil { + return nil, TokenUnknown, err + } + token := identifyToken(prefix) + + var ( + sz int + lenSz int + ) + // switch on the token + switch token { + case TokenDecimal: + // in this case, the value is just the byte itself + case TokenShortList: + sz = int(token.Diff(prefix)) + _, err = nextFull(w, sz) + case 
TokenLongList: + lenSz = int(token.Diff(prefix)) + sz, err = nextBeInt(w, lenSz) + if err != nil { + return nil, token, err + } + _, err = nextFull(w, sz) + case TokenShortBlob: + sz := int(token.Diff(prefix)) + _, err = nextFull(w, sz) + case TokenLongBlob: + lenSz := int(token.Diff(prefix)) + sz, err = nextBeInt(w, lenSz) + if err != nil { + return nil, token, err + } + _, err = nextFull(w, sz) + default: + return nil, token, fmt.Errorf("%w: unknown token", ErrDecode) + } + stop := w.Offset() + //log.Printf("%x %s\n", buf, token) + if err != nil { + return nil, token, err + } + return w.Underlying()[start:stop], token, nil +} + +func (d *Decoder) Elem() ([]byte, Token, error) { + w := d.buf + // figure out what we are reading + prefix, err := w.ReadByte() + if err != nil { + return nil, TokenUnknown, err + } + token := identifyToken(prefix) + + var ( + buf []byte + sz int + lenSz int + ) + // switch on the token + switch token { + case TokenDecimal: + // in this case, the value is just the byte itself + buf = []byte{prefix} + case TokenShortList: + sz = int(token.Diff(prefix)) + buf, err = nextFull(w, sz) + case TokenLongList: + lenSz = int(token.Diff(prefix)) + sz, err = nextBeInt(w, lenSz) + if err != nil { + return nil, token, err + } + buf, err = nextFull(w, sz) + case TokenShortBlob: + sz := int(token.Diff(prefix)) + buf, err = nextFull(w, sz) + case TokenLongBlob: + lenSz := int(token.Diff(prefix)) + sz, err = nextBeInt(w, lenSz) + if err != nil { + return nil, token, err + } + buf, err = nextFull(w, sz) + default: + return nil, token, fmt.Errorf("%w: unknown token", ErrDecode) + } + //log.Printf("%x %s\n", buf, token) + if err != nil { + return nil, token, fmt.Errorf("read data: %w", err) + } + return buf, token, nil +} + +func ReadElem[T any](d *Decoder, fn func(*T, []byte) error, receiver *T) error { + buf, token, err := d.Elem() + if err != nil { + return err + } + switch token { + case TokenDecimal, + TokenShortBlob, + TokenLongBlob, + TokenShortList, 
+ TokenLongList: + return fn(receiver, buf) + default: + return fmt.Errorf("%w: ReadElem found unexpected token", ErrDecode) + } +} + +func (d *Decoder) ForList(fn func(*Decoder) error) error { + // grab the list bytes + buf, token, err := d.Elem() + if err != nil { + return err + } + switch token { + case TokenShortList, TokenLongList: + dec := NewDecoder(buf) + for { + if dec.buf.Len() == 0 { + return nil + } + err := fn(dec) + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return err + } + // reset the byte + dec = NewDecoder(dec.Bytes()) + } + default: + return fmt.Errorf("%w: ForList on non-list", ErrDecode) + } +} + +type buf struct { + u []byte + off int +} + +func newBuf(u []byte, off int) *buf { + return &buf{u: u, off: off} +} + +func (b *buf) empty() bool { return len(b.u) <= b.off } + +func (b *buf) PeekByte() (n byte, err error) { + if len(b.u) <= b.off { + return 0, io.EOF + } + return b.u[b.off], nil +} +func (b *buf) ReadByte() (n byte, err error) { + if len(b.u) <= b.off { + return 0, io.EOF + } + b.off++ + return b.u[b.off-1], nil +} + +func (b *buf) Next(n int) (xs []byte) { + m := b.Len() + if n > m { + n = m + } + data := b.u[b.off : b.off+n] + b.off += n + return data +} + +func (b *buf) Offset() int { + return b.off +} + +func (b *buf) Bytes() []byte { + return b.u[b.off:] +} + +func (b *buf) String() string { + if b == nil { + // Special case, useful in debugging. 
+ return "" + } + return string(b.u[b.off:]) +} + +func (b *buf) Len() int { return len(b.u) - b.off } + +func (b *buf) Underlying() []byte { + return b.u +} diff --git a/rlp/encoder.go b/rlp/encoder.go new file mode 100644 index 000000000..402378b18 --- /dev/null +++ b/rlp/encoder.go @@ -0,0 +1,189 @@ +package rlp + +import "golang.org/x/exp/constraints" + +type EncoderFunc = func(i *Encoder) *Encoder + +type Encoder struct { + buf []byte +} + +func NewEncoder(buf []byte) *Encoder { + return &Encoder{ + buf: buf, + } +} + +// Buffer returns the underlying buffer +func (e *Encoder) Buffer() []byte { + return e.buf +} + +func (e *Encoder) Byte(p byte) *Encoder { + e.buf = append(e.buf, p) + return e +} + +func (e *Encoder) Bytes(p []byte) *Encoder { + e.buf = append(e.buf, p...) + return e +} + +// Str will write a string correctly +func (e *Encoder) Str(str []byte) *Encoder { + if len(str) > 55 { + return e.LongString(str) + } + return e.ShortString(str) +} + +// String will assume your string is less than 56 bytes long, and do no validation as such +func (e *Encoder) ShortString(str []byte) *Encoder { + return e.Byte(TokenShortBlob.Plus(byte(len(str)))).Bytes(str) +} + +// String will assume your string is greater than 55 bytes long, and do no validation as such +func (e *Encoder) LongString(str []byte) *Encoder { + // write the indicator token + e.Byte(byte(TokenLongBlob)) + // write the integer, knowing that we appended n bytes + n := putUint(e, len(str)) + // so we knw the indicator token was n+1 bytes ago. + e.buf[len(e.buf)-(int(n)+1)] += n + // and now add the actual length + e.buf = append(e.buf, str...) + return e +} + +// List will attempt to write the list of encoder funcs to the buf +func (e *Encoder) List(items ...EncoderFunc) *Encoder { + return e.writeList(true, items...) +} + +// ShortList actually calls List +func (e *Encoder) ShortList(items ...EncoderFunc) *Encoder { + return e.writeList(true, items...) 
+} + +// LongList will assume that your list payload is more than 55 bytes long, and do no validation as such +func (e *Encoder) LongList(items ...EncoderFunc) *Encoder { + return e.writeList(false, items...) +} + +// writeList will first attempt to write a long list with the dat +// if validate is false, it will just format it like the length is above 55 +// if validate is true, it will format it like it is a shrot list +func (e *Encoder) writeList(validate bool, items ...EncoderFunc) *Encoder { + // write the indicator token + e = e.Byte(byte(TokenLongList)) + // now pad 8 bytes + e = e.Bytes(make([]byte, 8)) + // record the length before encoding items + startLength := len(e.buf) + // now write all the items + for _, v := range items { + e = v(e) + } + // the size is the difference in the lengths now + dataSize := len(e.buf) - startLength + if dataSize <= 55 && validate { + // oh it's actually a short string! awkward. let's set that then. + e.buf[startLength-8-1] = TokenShortList.Plus(byte(dataSize)) + // and then copy the data over + copy(e.buf[startLength-8:], e.buf[startLength:startLength+dataSize]) + // and now set the new size + e.buf = e.buf[:startLength+dataSize-8] + // we are done, return + return e + } + // ok, so it's a long string. + // create a new encoder centered at startLength - 8 + enc := NewEncoder(e.buf[startLength-8:]) + // now write using that encoder the size + n := putUint(enc, dataSize) + // and update the token, which we know is at startLength-8-1 + e.buf[startLength-8-1] += n + // the shift to perform now is 8 - n. 
+ shift := int(8 - n) + // if there is a positive shift, then we must perform the shift + if shift > 0 { + // copy the data + copy(e.buf[startLength-shift:], e.buf[startLength:startLength+dataSize]) + // set the new length + e.buf = e.buf[:startLength-shift+dataSize] + } + return e +} + +func putUint[T constraints.Integer](e *Encoder, t T) (size byte) { + i := uint64(t) + switch { + case i < (1 << 8): + e.buf = append(e.buf, byte(i)) + return 1 + case i < (1 << 16): + e.buf = append(e.buf, + byte(i>>8), + byte(i), + ) + return 2 + case i < (1 << 24): + + e.buf = append(e.buf, + byte(i>>16), + byte(i>>8), + byte(i), + ) + return 3 + case i < (1 << 32): + e.buf = append(e.buf, + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + return 4 + case i < (1 << 40): + e.buf = append(e.buf, + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + return 5 + case i < (1 << 48): + e.buf = append(e.buf, + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + return 6 + case i < (1 << 56): + e.buf = append(e.buf, + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + return 7 + default: + e.buf = append(e.buf, + byte(i>>56), + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + return 8 + } +} diff --git a/rlp/parse.go b/rlp/parse.go index cbe59749a..449277f09 100644 --- a/rlp/parse.go +++ b/rlp/parse.go @@ -26,9 +26,10 @@ import ( ) var ( - ErrBase = fmt.Errorf("rlp") - ErrParse = fmt.Errorf("%w parse", ErrBase) - ErrDecode = fmt.Errorf("%w decode", ErrBase) + ErrBase = fmt.Errorf("rlp") + ErrParse = fmt.Errorf("%w parse", ErrBase) + ErrDecode = fmt.Errorf("%w decode", ErrBase) + ErrUnexpectedEOF = fmt.Errorf("%w EOF", ErrBase) ) func IsRLPError(err error) bool { return errors.Is(err, ErrBase) } @@ -36,8 +37,8 @@ func IsRLPError(err error) bool { return errors.Is(err, ErrBase) } // BeInt parses Big Endian 
representation of an integer from given payload at given position func BeInt(payload []byte, pos, length int) (int, error) { var r int - if pos+length >= len(payload) { - return 0, fmt.Errorf("%w: unexpected end of payload", ErrParse) + if pos+length > len(payload) { + return 0, ErrUnexpectedEOF } if length > 0 && payload[pos] == 0 { return 0, fmt.Errorf("%w: integer encoding for RLP must not have leading zeros: %x", ErrParse, payload[pos:pos+length]) diff --git a/rlp/readme.md b/rlp/readme.md new file mode 100644 index 000000000..74e9f96ee --- /dev/null +++ b/rlp/readme.md @@ -0,0 +1,11 @@ +## rlp + + +TERMINOLOGY: + +``` +RLP string = "Blob" // this is so we don't conflict with existing go name for String +RLP list = "List" // luckily we can keep using list name since go doesn't use it +RLP single byte number = "Decimal" // for numbers from 1-127. a special case +``` + diff --git a/rlp/types.go b/rlp/types.go new file mode 100644 index 000000000..f33bdfdc2 --- /dev/null +++ b/rlp/types.go @@ -0,0 +1,59 @@ +package rlp + +import ( + "fmt" + + "github.com/holiman/uint256" +) + +func Bytes(dst *[]byte, src []byte) error { + if len(*dst) < len(src) { + (*dst) = make([]byte, len(src)) + } + copy(*dst, src) + return nil +} +func BytesExact(dst *[]byte, src []byte) error { + if len(*dst) != len(src) { + return fmt.Errorf("%w: BytesExact no match", ErrDecode) + } + copy(*dst, src) + return nil +} + +func Uint256(dst *uint256.Int, src []byte) error { + if len(src) > 32 { + return fmt.Errorf("%w: uint256 must not be more than 32 bytes long, got %d", ErrParse, len(src)) + } + if len(src) > 0 && src[0] == 0 { + return fmt.Errorf("%w: integer encoding for RLP must not have leading zeros: %x", ErrParse, src) + } + dst.SetBytes(src) + return nil +} + +func Uint64(dst *uint64, src []byte) error { + var r uint64 + for _, b := range src { + r = (r << 8) | uint64(b) + } + (*dst) = r + return nil +} + +func IsEmpty(dst *bool, src []byte) error { + if len(src) == 0 { + (*dst) = true 
+ } else { + (*dst) = false + } + return nil +} +func BlobLength(dst *int, src []byte) error { + (*dst) = len(src) + return nil +} + +func Skip(dst *int, src []byte) error { + return nil +} diff --git a/rlp/unmarshaler.go b/rlp/unmarshaler.go new file mode 100644 index 000000000..16c42a1f2 --- /dev/null +++ b/rlp/unmarshaler.go @@ -0,0 +1,191 @@ +package rlp + +import ( + "fmt" + "reflect" +) + +type Unmarshaler interface { + UnmarshalRLP(data []byte) error +} + +func Unmarshal(data []byte, val any) error { + buf := newBuf(data, 0) + return unmarshal(buf, val) +} + +func unmarshal(buf *buf, val any) error { + rv := reflect.ValueOf(val) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return fmt.Errorf("%w: v must be ptr", ErrDecode) + } + v := rv.Elem() + err := reflectAny(buf, v, rv) + if err != nil { + return fmt.Errorf("%w: %w", ErrDecode, err) + } + return nil +} + +func reflectAny(w *buf, v reflect.Value, rv reflect.Value) error { + if um, ok := rv.Interface().(Unmarshaler); ok { + return um.UnmarshalRLP(w.Bytes()) + } + // figure out what we are reading + prefix, err := w.ReadByte() + if err != nil { + return err + } + token := identifyToken(prefix) + // switch + switch token { + case TokenDecimal: + // in this case, the value is just the byte itself + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.SetInt(int64(prefix)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + v.SetUint(uint64(prefix)) + case reflect.Invalid: + // do nothing + default: + return fmt.Errorf("%w: decimal must be unmarshal into integer type", ErrDecode) + } + case TokenShortBlob: + sz := int(token.Diff(prefix)) + str, err := nextFull(w, sz) + if err != nil { + return err + } + return putBlob(str, v, rv) + case TokenLongBlob: + lenSz := int(token.Diff(prefix)) + sz, err := nextBeInt(w, lenSz) + if err != nil { + return err + } + str, err := nextFull(w, sz) + if err != nil { + 
return err + } + return putBlob(str, v, rv) + case TokenShortList: + sz := int(token.Diff(prefix)) + buf, err := nextFull(w, sz) + if err != nil { + return err + } + return reflectList(newBuf(buf, 0), v, rv) + case TokenLongList: + lenSz := int(token.Diff(prefix)) + sz, err := nextBeInt(w, lenSz) + if err != nil { + return err + } + buf, err := nextFull(w, sz) + if err != nil { + return err + } + return reflectList(newBuf(buf, 0), v, rv) + case TokenUnknown: + return fmt.Errorf("%w: unknown token", ErrDecode) + } + return nil +} + +func putBlob(w []byte, v reflect.Value, rv reflect.Value) error { + switch v.Kind() { + case reflect.String: + v.SetString(string(w)) + case reflect.Slice: + if v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("%w: need to use uint8 as underlying if want slice output from longstring", ErrDecode) + } + v.SetBytes(w) + case reflect.Array: + if v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("%w: need to use uint8 as underlying if want array output from longstring", ErrDecode) + } + reflect.Copy(v, reflect.ValueOf(w)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + val, err := BeInt(w, 0, len(w)) + if err != nil { + return err + } + v.SetInt(int64(val)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + val, err := BeInt(w, 0, len(w)) + if err != nil { + return err + } + v.SetUint(uint64(val)) + case reflect.Invalid: + // do nothing + return nil + } + return nil +} + +func reflectList(w *buf, v reflect.Value, rv reflect.Value) error { + switch v.Kind() { + case reflect.Invalid: + // do nothing + return nil + case reflect.Map: + rv1 := reflect.New(v.Type().Key()) + v1 := rv1.Elem() + err := reflectAny(w, v1, rv1) + if err != nil { + return err + } + rv2 := reflect.New(v.Type().Elem()) + v2 := rv2.Elem() + err = reflectAny(w, v2, rv2) + if err != nil { + return err + } + v.SetMapIndex(rv1, rv2) + case reflect.Struct: + for idx := 0; 
idx < v.NumField(); idx++ { + // Decode into element. + rv1 := v.Field(idx).Addr() + rt1 := v.Type().Field(idx) + v1 := rv1.Elem() + shouldSet := rt1.IsExported() + if shouldSet { + err := reflectAny(w, v1, rv1) + if err != nil { + return err + } + } + } + case reflect.Array, reflect.Slice: + idx := 0 + for { + if idx >= v.Cap() { + v.Grow(1) + } + if idx >= v.Len() { + v.SetLen(idx + 1) + } + if idx < v.Len() { + // Decode into element. + rv1 := v.Index(idx) + v1 := rv1.Elem() + err := reflectAny(w, v1, rv1) + if err != nil { + return err + } + } else { + // Ran out of fixed array: skip. + rv1 := reflect.Value{} + err := reflectAny(w, rv1, rv1) + if err != nil { + return err + } + } + idx++ + } + } + return nil +} diff --git a/rlp/unmarshaler_test.go b/rlp/unmarshaler_test.go new file mode 100644 index 000000000..bd0be6a87 --- /dev/null +++ b/rlp/unmarshaler_test.go @@ -0,0 +1,66 @@ +package rlp_test + +import ( + "testing" + + "github.com/ledgerwatch/erigon-lib/rlp" + "github.com/stretchr/testify/require" +) + +type plusOne int + +func (p *plusOne) UnmarshalRLP(data []byte) error { + var s int + err := rlp.Unmarshal(data, &s) + if err != nil { + return err + } + (*p) = plusOne(s + 1) + return nil +} + +func TestDecoder(t *testing.T) { + + type simple struct { + Key string + Value string + } + + t.Run("ShortString", func(t *testing.T) { + t.Run("ToString", func(t *testing.T) { + bts := []byte{0x83, 'd', 'o', 'g'} + var s string + err := rlp.Unmarshal(bts, &s) + require.NoError(t, err) + require.EqualValues(t, "dog", s) + }) + t.Run("ToBytes", func(t *testing.T) { + bts := []byte{0x83, 'd', 'o', 'g'} + var s []byte + err := rlp.Unmarshal(bts, &s) + require.NoError(t, err) + require.EqualValues(t, []byte("dog"), s) + }) + t.Run("ToInt", func(t *testing.T) { + bts := []byte{0x82, 0x04, 0x00} + var s int + err := rlp.Unmarshal(bts, &s) + require.NoError(t, err) + require.EqualValues(t, 1024, s) + }) + t.Run("ToIntUnmarshaler", func(t *testing.T) { + bts := 
[]byte{0x82, 0x04, 0x00} + var s plusOne + err := rlp.Unmarshal(bts, &s) + require.NoError(t, err) + require.EqualValues(t, plusOne(1025), s) + }) + t.Run("ToSimpleStruct", func(t *testing.T) { + bts := []byte{0xc8, 0x83, 'c', 'a', 't', 0x83, 'd', 'o', 'g'} + var s simple + err := rlp.Unmarshal(bts, &s) + require.NoError(t, err) + require.EqualValues(t, simple{Key: "cat", Value: "dog"}, s) + }) + }) +} diff --git a/rlp/util.go b/rlp/util.go new file mode 100644 index 000000000..0219e1d39 --- /dev/null +++ b/rlp/util.go @@ -0,0 +1,84 @@ +package rlp + +type Token int32 + +func (T Token) String() string { + switch T { + case TokenDecimal: + return "decimal" + case TokenShortBlob: + return "short_blob" + case TokenLongBlob: + return "long_blob" + case TokenShortList: + return "short_list" + case TokenLongList: + return "long_list" + case TokenEOF: + return "eof" + case TokenUnknown: + return "unknown" + default: + return "nan" + } +} + +func (T Token) Plus(n byte) byte { + return byte(T) + n +} + +func (T Token) Diff(n byte) byte { + return n - byte(T) +} + +func (T Token) IsListType() bool { + return T == TokenLongList || T == TokenShortList +} + +func (T Token) IsBlobType() bool { + return T == TokenLongBlob || T == TokenShortBlob +} + +const ( + TokenDecimal Token = 0x00 + TokenShortBlob Token = 0x80 + TokenLongBlob Token = 0xb7 + TokenShortList Token = 0xc0 + TokenLongList Token = 0xf7 + + TokenUnknown Token = 0xff01 + TokenEOF Token = 0xdead +) + +func identifyToken(b byte) Token { + switch { + case b <= 127: + return TokenDecimal + case b >= 128 && b <= 183: + return TokenShortBlob + case b >= 184 && b <= 191: + return TokenLongBlob + case b >= 192 && b <= 247: + return TokenShortList + case b >= 248 && b <= 255: + return TokenLongList + } + return TokenUnknown +} + +// BeInt parses Big Endian representation of an integer from given payload at given position +func nextBeInt(w *buf, length int) (int, error) { + dat, err := nextFull(w, length) + if err != nil { + 
return 0, ErrUnexpectedEOF + } + return BeInt(dat, 0, length) +} + +func nextFull(dat *buf, size int) ([]byte, error) { + d := dat.Next(size) + if len(d) != size { + return nil, ErrUnexpectedEOF + } + return d, nil +} diff --git a/txpool/fetch.go b/txpool/fetch.go index b319f638b..355916f54 100644 --- a/txpool/fetch.go +++ b/txpool/fetch.go @@ -326,7 +326,7 @@ func (f *Fetch) handleInboundMessage(ctx context.Context, req *sentry.InboundMes switch req.Id { case sentry.MessageId_TRANSACTIONS_66: if err := f.threadSafeParsePooledTxn(func(parseContext *types2.TxParseContext) error { - if _, err := types2.ParseTransactions(req.Data, 0, parseContext, &txs, func(hash []byte) error { + if err := types2.DecodeTransactions(rlp.NewDecoder(req.Data), parseContext, &txs, func(hash []byte) error { known, err := f.pool.IdHashKnown(tx, hash) if err != nil { return err diff --git a/types/txn.go b/types/txn.go index a1e7f8224..db1324c16 100644 --- a/types/txn.go +++ b/types/txn.go @@ -21,7 +21,6 @@ import ( "encoding/binary" "errors" "fmt" - "hash" "io" "math/bits" "sort" @@ -29,13 +28,13 @@ import ( gokzg4844 "github.com/crate-crypto/go-kzg-4844" "github.com/holiman/uint256" "github.com/ledgerwatch/secp256k1" - "golang.org/x/crypto/sha3" "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon-lib/common/fixedgas" "github.com/ledgerwatch/erigon-lib/common/length" "github.com/ledgerwatch/erigon-lib/common/u256" "github.com/ledgerwatch/erigon-lib/crypto" + "github.com/ledgerwatch/erigon-lib/crypto/cryptopool" "github.com/ledgerwatch/erigon-lib/gointerfaces/types" "github.com/ledgerwatch/erigon-lib/rlp" ) @@ -47,8 +46,6 @@ type TxParseConfig struct { // TxParseContext is object that is required to parse transactions and turn transaction payload into TxSlot objects // usage of TxContext helps avoid extra memory allocations type TxParseContext struct { - Keccak2 hash.Hash - Keccak1 hash.Hash validateRlp func([]byte) error ChainID uint256.Int // Signature values R 
uint256.Int // Signature values @@ -72,8 +69,6 @@ func NewTxParseContext(chainID uint256.Int) *TxParseContext { } ctx := &TxParseContext{ withSender: true, - Keccak1: sha3.NewLegacyKeccak256(), - Keccak2: sha3.NewLegacyKeccak256(), } // behave as of London enabled @@ -146,7 +141,20 @@ func PeekTransactionType(serialized []byte) (byte, error) { // It also performs syntactic validation of the transactions. // wrappedWithBlobs means that for blob (type 3) transactions the full version with blobs/commitments/proofs is expected // (see https://eips.ethereum.org/EIPS/eip-4844#networking). +// +// [ ] || [ nonce, price, limit, to, value, data, y,r,s] +// 0x01 || [chain_id, nonce, price, limit, to, value, data, access_list, y,r,s] +// 0x02 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, y,r,s] +// 0x03 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s] +// 0x03 ||[[chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s], blobs, commitments, proofs] func (ctx *TxParseContext) ParseTransaction(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (p int, err error) { + p, err = ctx.parseTransaction(payload, pos, slot, sender, hasEnvelope, wrappedWithBlobs, validateHash) + if err != nil { + return p, fmt.Errorf("%w: %w", ErrParseTxn, err) + } + return p, nil +} +func (ctx *TxParseContext) parseTransaction(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (p int, err error) { if len(payload) == 0 { return 0, fmt.Errorf("%w: empty rlp", ErrParseTxn) } @@ -287,15 +295,21 @@ func (ctx *TxParseContext) parseTransactionBody(payload []byte, pos, p0 int, slo p = p0 legacy := slot.Type == LegacyTxType + k1 := cryptopool.GetLegacyKeccak256() + k2 := cryptopool.GetLegacyKeccak256() + defer cryptopool.ReturnLegacyKeccak256(k1) + defer 
cryptopool.ReturnLegacyKeccak256(k2) + + k1.Reset() + k2.Reset() + // Compute transaction hash - ctx.Keccak1.Reset() - ctx.Keccak2.Reset() if !legacy { typeByte := []byte{slot.Type} - if _, err = ctx.Keccak1.Write(typeByte); err != nil { + if _, err = k1.Write(typeByte); err != nil { return 0, fmt.Errorf("%w: computing IdHash (hashing type Prefix): %s", ErrParseTxn, err) //nolint } - if _, err = ctx.Keccak2.Write(typeByte); err != nil { + if _, err = k2.Write(typeByte); err != nil { return 0, fmt.Errorf("%w: computing signHash (hashing type Prefix): %s", ErrParseTxn, err) //nolint } dataPos, dataLen, err := rlp.List(payload, p) @@ -303,7 +317,7 @@ func (ctx *TxParseContext) parseTransactionBody(payload []byte, pos, p0 int, slo return 0, fmt.Errorf("%w: envelope Prefix: %s", ErrParseTxn, err) //nolint } // Hash the content of envelope, not the full payload - if _, err = ctx.Keccak1.Write(payload[p : dataPos+dataLen]); err != nil { + if _, err = k1.Write(payload[p : dataPos+dataLen]); err != nil { return 0, fmt.Errorf("%w: computing IdHash (hashing the envelope): %s", ErrParseTxn, err) //nolint } p = dataPos @@ -520,12 +534,12 @@ func (ctx *TxParseContext) parseTransactionBody(payload []byte, pos, p0 int, slo // For legacy transactions, hash the full payload if legacy { - if _, err = ctx.Keccak1.Write(payload[pos:p]); err != nil { + if _, err = k1.Write(payload[pos:p]); err != nil { return 0, fmt.Errorf("%w: computing IdHash: %s", ErrParseTxn, err) //nolint } } //ctx.keccak1.Sum(slot.IdHash[:0]) - _, _ = ctx.Keccak1.(io.Reader).Read(slot.IDHash[:32]) + _, _ = k1.(io.Reader).Read(slot.IDHash[:32]) if validateHash != nil { if err := validateHash(slot.IDHash[:32]); err != nil { return p, err @@ -544,25 +558,25 @@ func (ctx *TxParseContext) parseTransactionBody(payload []byte, pos, p0 int, slo // Write len Prefix to the Sighash if sigHashLen < 56 { ctx.buf[0] = byte(sigHashLen) + 192 - if _, err := ctx.Keccak2.Write(ctx.buf[:1]); err != nil { + if _, err := 
k2.Write(ctx.buf[:1]); err != nil { return 0, fmt.Errorf("%w: computing signHash (hashing len Prefix): %s", ErrParseTxn, err) //nolint } } else { beLen := common.BitLenToByteLen(bits.Len(sigHashLen)) binary.BigEndian.PutUint64(ctx.buf[1:], uint64(sigHashLen)) ctx.buf[8-beLen] = byte(beLen) + 247 - if _, err := ctx.Keccak2.Write(ctx.buf[8-beLen : 9]); err != nil { + if _, err := k2.Write(ctx.buf[8-beLen : 9]); err != nil { return 0, fmt.Errorf("%w: computing signHash (hashing len Prefix): %s", ErrParseTxn, err) //nolint } } - if _, err = ctx.Keccak2.Write(payload[sigHashPos:sigHashEnd]); err != nil { + if _, err = k2.Write(payload[sigHashPos:sigHashEnd]); err != nil { return 0, fmt.Errorf("%w: computing signHash: %s", ErrParseTxn, err) //nolint } if legacy { if chainIDLen > 0 { if chainIDBits <= 7 { ctx.buf[0] = byte(ctx.ChainID.Uint64()) - if _, err := ctx.Keccak2.Write(ctx.buf[:1]); err != nil { + if _, err := k2.Write(ctx.buf[:1]); err != nil { return 0, fmt.Errorf("%w: computing signHash (hashing legacy chainId): %s", ErrParseTxn, err) //nolint } } else { @@ -571,20 +585,20 @@ func (ctx *TxParseContext) parseTransactionBody(payload []byte, pos, p0 int, slo binary.BigEndian.PutUint64(ctx.buf[17:25], ctx.ChainID[1]) binary.BigEndian.PutUint64(ctx.buf[25:33], ctx.ChainID[0]) ctx.buf[32-chainIDLen] = 128 + byte(chainIDLen) - if _, err = ctx.Keccak2.Write(ctx.buf[32-chainIDLen : 33]); err != nil { + if _, err = k2.Write(ctx.buf[32-chainIDLen : 33]); err != nil { return 0, fmt.Errorf("%w: computing signHash (hashing legacy chainId): %s", ErrParseTxn, err) //nolint } } // Encode two zeros ctx.buf[0] = 128 ctx.buf[1] = 128 - if _, err := ctx.Keccak2.Write(ctx.buf[:2]); err != nil { + if _, err := k2.Write(ctx.buf[:2]); err != nil { return 0, fmt.Errorf("%w: computing signHash (hashing zeros after legacy chainId): %s", ErrParseTxn, err) //nolint } } } // Squeeze Sighash - _, _ = ctx.Keccak2.(io.Reader).Read(ctx.Sighash[:32]) + _, _ = k2.(io.Reader).Read(ctx.Sighash[:32]) 
//ctx.keccak2.Sum(ctx.Sighash[:0]) binary.BigEndian.PutUint64(ctx.Sig[0:8], ctx.R[3]) binary.BigEndian.PutUint64(ctx.Sig[8:16], ctx.R[2]) @@ -600,13 +614,13 @@ func (ctx *TxParseContext) parseTransactionBody(payload []byte, pos, p0 int, slo return 0, fmt.Errorf("%w: recovering sender from signature: %s", ErrParseTxn, err) //nolint } //apply keccak to the public key - ctx.Keccak2.Reset() - if _, err = ctx.Keccak2.Write(ctx.buf[1:65]); err != nil { + k2.Reset() + if _, err = k2.Write(ctx.buf[1:65]); err != nil { return 0, fmt.Errorf("%w: computing sender from public key: %s", ErrParseTxn, err) //nolint } // squeeze the hash of the public key //ctx.keccak2.Sum(ctx.buf[:0]) - _, _ = ctx.Keccak2.(io.Reader).Read(ctx.buf[:32]) + _, _ = k2.(io.Reader).Read(ctx.buf[:32]) //take last 20 bytes as address copy(sender, ctx.buf[12:32]) diff --git a/types/txn_decode.go b/types/txn_decode.go new file mode 100644 index 000000000..7f1c6c2da --- /dev/null +++ b/types/txn_decode.go @@ -0,0 +1,504 @@ +package types + +import ( + "encoding/binary" + "fmt" + "io" + "math/bits" + + gokzg4844 "github.com/crate-crypto/go-kzg-4844" + "github.com/ledgerwatch/erigon-lib/common" + "github.com/ledgerwatch/erigon-lib/common/u256" + "github.com/ledgerwatch/erigon-lib/crypto" + "github.com/ledgerwatch/erigon-lib/crypto/cryptopool" + "github.com/ledgerwatch/erigon-lib/rlp" + "github.com/ledgerwatch/secp256k1" +) + +// DecodeTransaction extracts all the information from the transactions's payload (RLP) necessary to build TxSlot. +// It also performs syntactic validation of the transactions. +// wrappedWithBlobs means that for blob (type 3) transactions the full version with blobs/commitments/proofs is expected +// (see https://eips.ethereum.org/EIPS/eip-4844#networking). 
+// +// // [ ] || [ nonce, price, limit, to, value, data, y,r,s] +// // 0x01 || [chain_id, nonce, price, limit, to, value, data, access_list, y,r,s] +// // 0x02 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, y,r,s] +// // 0x03 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s] +// // 0x03 ||[[chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s], blobs, commitments, proofs] +func (ctx *TxParseContext) DecodeTransaction(decoder *rlp.Decoder, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (err error) { + err = ctx.decodeTransaction(decoder, slot, sender, hasEnvelope, wrappedWithBlobs, validateHash) + if err != nil { + return fmt.Errorf("%w: %w", ErrParseTxn, err) + } + return nil +} + +func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (err error) { + od := decoder + if decoder.Empty() { + return fmt.Errorf("empty rlp") + } + if ctx.withSender && len(sender) != 20 { + return fmt.Errorf("expect sender buffer of len 20") + } + // start classification + token, err := decoder.PeekToken() + if err != nil { + return err + } + // means that this is non-enveloped non-legacy transaction + if token == rlp.TokenDecimal { + if hasEnvelope { + return fmt.Errorf("expected envelope in the payload, got %s", token) + } + } + + var ( + dec *rlp.Decoder + ) + var ( + parent *rlp.Decoder // the parent should contain in its underlying buffer the rlp used for txn hash creation sans txn type + bodyDecoder *rlp.Decoder // the bodyDecoder should be an rlp decoder primed at the top of the list body for txn + ) + + switch { + default: + return fmt.Errorf("expected list or blob token, got %s", token) + case token.IsListType(): // Legacy transactions have list Prefix, + // enter the list + parent = decoder + bodyDecoder, 
_, err = decoder.ElemDec() + if err != nil { + return fmt.Errorf("size prefix: %w", err) //nolint + } + slot.Rlp = append([]byte{}, decoder.Consumed()...) + slot.Type = LegacyTxType + case token.IsBlobType() || token == rlp.TokenDecimal: // EIP-2718 transactions have string Prefix + // if is blob type, it means that its an envelope, so we need to get out of that + if token.IsBlobType() { + decoder, _, err = decoder.ElemDec() + if err != nil { + return fmt.Errorf("size prefix: %w", err) //nolint + } + if decoder.Empty() { + return fmt.Errorf("transaction must be either 1 list or 1 string") + } + } + slot.Rlp = append([]byte{}, decoder.Bytes()...) + // at this point, the next byte is the type + slot.Type, err = decoder.ReadByte() + if err != nil { + return fmt.Errorf("couldn't read txn type: %w", err) + } + if slot.Type > BlobTxType { + return fmt.Errorf("unknown transaction type: %d", slot.Type) + } + // from here to the end of the element, if this is not a blob tx type with blobs, is the parent + parent, _, err = decoder.RawElemDec() + if err != nil { + return fmt.Errorf("extract txn body parent: %w", err) //nolint + } + // now enter the list, since that is what we are in front of now. + bodyDecoder, _, err = parent.ElemDec() + if err != nil { + return fmt.Errorf("extract txn body: %w", err) //nolint + } + if slot.Type == BlobTxType { + if wrappedWithBlobs { + dec = bodyDecoder + // if its a blob transaction and wrapped with blobs, we actually need to enter a nested list + // in this case, "decoder" was an iterator for the array of [ [txbody...], blobs, commitments, proofs] + // so our bodyDecoder is actually a pointer to the header of [txbody...]. 
+ // we can extract the raw elem out of this in order to get the parent decoder + parent, _, err = bodyDecoder.RawElemDec() + if err != nil { + return fmt.Errorf("wrapped blob tx body parent: %w", err) //nolint + } + // and then the body is actually just the parent decoder read once, since we are sitting at the top of the [txbody...] header + bodyDecoder, _, err = parent.ElemDec() + if err != nil { + return fmt.Errorf("wrapped blob tx body: %w", err) //nolint + } + } + } + } + err = ctx.decodeTransactionBody(bodyDecoder, parent, slot, sender, validateHash) + if err != nil { + return fmt.Errorf("txn body: %w", err) + } + + // so its a blob transaction and we need to do the extra stuff, otherwise we are done + if slot.Type == BlobTxType && wrappedWithBlobs { + if err := ctx.decodeBlobs(dec, slot); err != nil { + return fmt.Errorf("decode blobs: %w", err) + } + if err := ctx.decodeCommitments(dec, slot); err != nil { + return fmt.Errorf("decode commitments: %w", err) + } + if err := ctx.decodeProofs(dec, slot); err != nil { + return fmt.Errorf("decode proofs: %w", err) + } + if len(slot.Blobs) != len(slot.Commitments) { + return fmt.Errorf("blob count != commitment count") + } + if len(slot.Commitments) != len(slot.Proofs) { + return fmt.Errorf("commitment count != proof count") + } + if len(slot.BlobHashes) != len(slot.Blobs) { + return fmt.Errorf("blob count != blob hash count") + } + } + slot.Size = uint32(od.Offset()) + return err +} + +func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.Decoder, slot *TxSlot, sender []byte, validateHash func([]byte) error) (err error) { + legacy := slot.Type == LegacyTxType + + k1 := cryptopool.GetLegacyKeccak256() + k2 := cryptopool.GetLegacyKeccak256() + defer cryptopool.ReturnLegacyKeccak256(k1) + defer cryptopool.ReturnLegacyKeccak256(k2) + + k1.Reset() + k2.Reset() + + // for computing tx hash + if !legacy { + typeByte := []byte{slot.Type} + if _, err := k1.Write(typeByte); err != nil { + return 
fmt.Errorf("compute idHash: %w", err) + } + if _, err := k2.Write(typeByte); err != nil { + return fmt.Errorf("compute Hash: %w", err) + } + if _, err := k1.Write(parent.Underlying()); err != nil { + return fmt.Errorf("compute idHash: %w", err) + } + } + + if ctx.validateRlp != nil { + if err := ctx.validateRlp(slot.Rlp); err != nil { + return fmt.Errorf("validate rlp: %w", err) + } + } + + // signing hash data starts here + sigHashPos := dec.Offset() + if !legacy { + err = rlp.ReadElem(dec, rlp.Uint256, &ctx.ChainID) + if err != nil { + return fmt.Errorf("bad chainId: %w", err) //nolint + } + if ctx.ChainID.IsZero() { // zero indicates that the chain ID was not specified in the tx. + if ctx.chainIDRequired { + return fmt.Errorf("chainID is required") + } + ctx.ChainID.Set(&ctx.cfg.ChainID) + } + if !ctx.ChainID.Eq(&ctx.cfg.ChainID) { + return fmt.Errorf("%s, %d (expected %d)", "invalid chainID", ctx.ChainID.Uint64(), ctx.cfg.ChainID.Uint64()) + } + } + + // Next follows the nonce, which we need to parse + err = rlp.ReadElem(dec, rlp.Uint64, &slot.Nonce) + if err != nil { + return fmt.Errorf("nonce: %s", err) //nolint + } + // Next follows gas price or tip + err = rlp.ReadElem(dec, rlp.Uint256, &slot.Tip) + if err != nil { + return fmt.Errorf("tip: %s", err) //nolint + } + // Next follows feeCap, but only for dynamic fee transactions, for legacy transaction, it is + // equal to tip + if slot.Type < DynamicFeeTxType { + slot.FeeCap = slot.Tip + } else { + err = rlp.ReadElem(dec, rlp.Uint256, &slot.FeeCap) + if err != nil { + return fmt.Errorf("feeCap: %s", err) //nolint + } + } + // gas limit + err = rlp.ReadElem(dec, rlp.Uint64, &slot.Gas) + if err != nil { + return fmt.Errorf("gas: %s", err) //nolint + } + // recipient + err = rlp.ReadElem(dec, rlp.IsEmpty, &slot.Creation) + if err != nil { + return fmt.Errorf("value: %s", err) //nolint + } + // Next follows value + err = rlp.ReadElem(dec, rlp.Uint256, &slot.Value) + if err != nil { + return fmt.Errorf("value: 
%s", err) //nolint + } + // Next goes data, but we are only interesting in its length + // Zero and non-zero bytes are priced differently, so reset the counter before counting + slot.DataNonZeroLen = 0 + err = rlp.ReadElem(dec, func(i *int, b []byte) error { + slot.DataLen = len(b) + for _, byt := range b { + if byt != 0 { + slot.DataNonZeroLen++ + } + } + return nil + }, nil) + if err != nil { + return fmt.Errorf("data len: %s", err) //nolint + } + // Next follows access list for non-legacy transactions, we are only interesting in number of addresses and storage keys + if !legacy { + err = dec.ForList(func(ld *rlp.Decoder) error { + slot.AlAddrCount++ + err := rlp.ReadElem(ld, rlp.Skip, nil) + if err != nil { + return fmt.Errorf("elem: %w", err) + } + err = ld.ForList(func(sk *rlp.Decoder) error { + slot.AlStorCount++ + err := rlp.ReadElem(sk, rlp.Skip, nil) + if err != nil { + return fmt.Errorf("elem: %w", err) + } + return nil + }) + if err != nil { + return fmt.Errorf("iterate: %w", err) + } + return nil + }) + if err != nil { + return fmt.Errorf("iterate: %w", err) + } + } + + if slot.Type == BlobTxType { + err = rlp.ReadElem(dec, rlp.Uint256, &slot.BlobFeeCap) + if err != nil { + return fmt.Errorf("blob fee cap: %s", err) //nolint + } + err = dec.ForList(func(d *rlp.Decoder) error { + var blob common.Hash + blobSlice := blob[:] + err := rlp.ReadElem(d, rlp.BytesExact, &blobSlice) + if err != nil { + return err + } + slot.BlobHashes = append(slot.BlobHashes, blob) + return nil + }) + if err != nil { + return fmt.Errorf("blob hashes: %w", err) + } + } + // This is where the data for Sighash ends + sigHashEnd := dec.Offset() + sigHashLen := uint(sigHashEnd - sigHashPos) + + // Next follows V of the signature + var vByte byte + var chainIDBits, chainIDLen int + if legacy { + err = rlp.ReadElem(dec, rlp.Uint256, &ctx.V) + if err != nil { + return fmt.Errorf("V: %s", err) //nolint + } + ctx.IsProtected = ctx.V.Eq(u256.N27) || ctx.V.Eq(u256.N28) + // Compute chainId from V + if ctx.IsProtected { + // Do not add chain id and two extra zeros + vByte = byte(ctx.V.Uint64() - 
27) + ctx.ChainID.Set(&ctx.cfg.ChainID) + } else { + ctx.ChainID.Sub(&ctx.V, u256.N35) + ctx.ChainID.Rsh(&ctx.ChainID, 1) + if !ctx.ChainID.Eq(&ctx.cfg.ChainID) { + return fmt.Errorf("%s, %d (expected %d)", "invalid chainID", ctx.ChainID.Uint64(), ctx.cfg.ChainID.Uint64()) + } + + chainIDBits = ctx.ChainID.BitLen() + if chainIDBits <= 7 { + chainIDLen = 1 + } else { + chainIDLen = common.BitLenToByteLen(chainIDBits) // It is always < 56 bytes + sigHashLen++ // For chainId len Prefix + } + sigHashLen += uint(chainIDLen) // For chainId + sigHashLen += 2 // For two extra zeros + ctx.DeriveChainID.Sub(&ctx.V, &ctx.ChainIDMul) + vByte = byte(ctx.DeriveChainID.Sub(&ctx.DeriveChainID, u256.N8).Uint64() - 27) + } + } else { + var v uint64 + err = rlp.ReadElem(dec, rlp.Uint64, &v) + if err != nil { + return fmt.Errorf("V: %s", err) //nolint + } + if v > 1 { + return fmt.Errorf("V is loo large: %d", v) + } + vByte = byte(v) + ctx.IsProtected = true + } + + // Next follows R of the signature + err = rlp.ReadElem(dec, rlp.Uint256, &ctx.R) + if err != nil { + return fmt.Errorf("R: %s", err) //nolint + } + // New follows S of the signature + err = rlp.ReadElem(dec, rlp.Uint256, &ctx.S) + if err != nil { + return fmt.Errorf("S: %s", err) //nolint + } + // For legacy transactions, just hash the full payload + if legacy { + if _, err = k1.Write(parent.Consumed()); err != nil { + return fmt.Errorf("computing IdHash: %s", err) //nolint + } + } + + // write the hash to IdHash buffer + //k1.Sum(slot.IDHash[:0]) + _, _ = k1.(io.Reader).Read(slot.IDHash[:32]) + + if validateHash != nil { + if err := validateHash(slot.IDHash[:32]); err != nil { + return err + } + } + + if !ctx.withSender { + return nil + } + + if !crypto.TransactionSignatureIsValid(vByte, &ctx.R, &ctx.S, ctx.allowPreEip2s && legacy) { + return fmt.Errorf("invalid v, r, s: %d, %s, %s", vByte, &ctx.R, &ctx.S) + } + + // Computing sigHash (hash used to recover sender from the signature) + // Write len Prefix to the Sighash + 
if sigHashLen < 56 { + ctx.buf[0] = byte(sigHashLen) + 192 + if _, err := k2.Write(ctx.buf[:1]); err != nil { + return fmt.Errorf("computing signHash (hashing len Prefix1): %s", err) //nolint + } + } else { + beLen := common.BitLenToByteLen(bits.Len(sigHashLen)) + binary.BigEndian.PutUint64(ctx.buf[1:], uint64(sigHashLen)) + ctx.buf[8-beLen] = byte(beLen) + 247 + if _, err := k2.Write(ctx.buf[8-beLen : 9]); err != nil { + return fmt.Errorf("computing signHash (hashing len Prefix2): %s", err) //nolint + } + } + if _, err = k2.Write(dec.Underlying()[sigHashPos:sigHashEnd]); err != nil { + return fmt.Errorf("computing signHash: %s", err) //nolint + } + if legacy { + if chainIDLen > 0 { + if chainIDBits <= 7 { + ctx.buf[0] = byte(ctx.ChainID.Uint64()) + if _, err := k2.Write(ctx.buf[:1]); err != nil { + return fmt.Errorf("computing signHash (hashing legacy chainId): %s", err) //nolint + } + } else { + binary.BigEndian.PutUint64(ctx.buf[1:9], ctx.ChainID[3]) + binary.BigEndian.PutUint64(ctx.buf[9:17], ctx.ChainID[2]) + binary.BigEndian.PutUint64(ctx.buf[17:25], ctx.ChainID[1]) + binary.BigEndian.PutUint64(ctx.buf[25:33], ctx.ChainID[0]) + ctx.buf[32-chainIDLen] = 128 + byte(chainIDLen) + if _, err = k2.Write(ctx.buf[32-chainIDLen : 33]); err != nil { + return fmt.Errorf("computing signHash (hashing legacy chainId): %s", err) //nolint + } + } + // Encode two zeros + ctx.buf[0] = 128 + ctx.buf[1] = 128 + if _, err := k2.Write(ctx.buf[:2]); err != nil { + return fmt.Errorf("computing signHash (hashing zeros after legacy chainId): %s", err) //nolint + } + } + } + // Squeeze Sighash + _, _ = k2.(io.Reader).Read(ctx.Sighash[:32]) + //k2.Sum(ctx.Sighash[:0]) + binary.BigEndian.PutUint64(ctx.Sig[0:8], ctx.R[3]) + binary.BigEndian.PutUint64(ctx.Sig[8:16], ctx.R[2]) + binary.BigEndian.PutUint64(ctx.Sig[16:24], ctx.R[1]) + binary.BigEndian.PutUint64(ctx.Sig[24:32], ctx.R[0]) + binary.BigEndian.PutUint64(ctx.Sig[32:40], ctx.S[3]) + binary.BigEndian.PutUint64(ctx.Sig[40:48], 
ctx.S[2]) + binary.BigEndian.PutUint64(ctx.Sig[48:56], ctx.S[1]) + binary.BigEndian.PutUint64(ctx.Sig[56:64], ctx.S[0]) + ctx.Sig[64] = vByte + // recover sender + if _, err = secp256k1.RecoverPubkeyWithContext(secp256k1.DefaultContext, ctx.Sighash[:], ctx.Sig[:], ctx.buf[:0]); err != nil { + return fmt.Errorf("recovering sender from signature: %s", err) //nolint + } + //apply keccak to the public key + k2.Reset() + if _, err = k2.Write(ctx.buf[1:65]); err != nil { + return fmt.Errorf("computing sender from public key: %s", err) //nolint + } + // squeeze the hash of the public key + //k2.Sum(ctx.buf[:0]) + _, _ = k2.(io.Reader).Read(ctx.buf[:32]) + //take last 20 bytes as address + copy(sender, ctx.buf[12:32]) + + return nil +} + +func (ctx *TxParseContext) decodeCommitments(dec *rlp.Decoder, slot *TxSlot) (err error) { + err = dec.ForList(func(d *rlp.Decoder) error { + var blob gokzg4844.KZGCommitment + blobSlice := blob[:] + err := rlp.ReadElem(d, rlp.BytesExact, &blobSlice) + if err != nil { + return err + } + slot.Commitments = append(slot.Commitments, blob) + return nil + }) + + if err != nil { + return err + } + return nil +} + +func (ctx *TxParseContext) decodeProofs(dec *rlp.Decoder, slot *TxSlot) (err error) { + err = dec.ForList(func(d *rlp.Decoder) error { + var blob gokzg4844.KZGProof + blobSlice := blob[:] + err := rlp.ReadElem(d, rlp.BytesExact, &blobSlice) + if err != nil { + return err + } + slot.Proofs = append(slot.Proofs, blob) + return nil + }) + if err != nil { + return err + } + return nil +} + +func (ctx *TxParseContext) decodeBlobs(dec *rlp.Decoder, slot *TxSlot) (err error) { + err = dec.ForList(func(d *rlp.Decoder) error { + var blob []byte + err := rlp.ReadElem(d, rlp.Bytes, &blob) + if err != nil { + return err + } + slot.Blobs = append(slot.Blobs, blob) + return nil + }) + if err != nil { + return err + } + + return nil +} diff --git a/types/txn_packets.go b/types/txn_packets.go index 3f7ed58bd..630341752 100644 --- 
a/types/txn_packets.go +++ b/types/txn_packets.go @@ -191,7 +191,6 @@ func ParseTransactions(payload []byte, pos int, ctx *TxParseContext, txSlots *Tx } return pos, nil } - func ParsePooledTransactions66(payload []byte, pos int, ctx *TxParseContext, txSlots *TxSlots, validateHash func([]byte) error) (requestID uint64, newPos int, err error) { p, _, err := rlp.List(payload, pos) if err != nil { diff --git a/types/txn_packets_decode.go b/types/txn_packets_decode.go new file mode 100644 index 000000000..34721f825 --- /dev/null +++ b/types/txn_packets_decode.go @@ -0,0 +1,30 @@ +package types + +import ( + "errors" + "fmt" + + "github.com/ledgerwatch/erigon-lib/rlp" +) + +func DecodeTransactions(dec *rlp.Decoder, ctx *TxParseContext, txSlots *TxSlots, validateHash func([]byte) error) (err error) { + i := 0 + err = dec.ForList(func(d *rlp.Decoder) error { + txSlots.Resize(uint(i + 1)) + txSlots.Txs[i] = &TxSlot{} + err = ctx.DecodeTransaction(d, txSlots.Txs[i], txSlots.Senders.At(i), true /* hasEnvelope */, true /* wrappedWithBlobs */, validateHash) + if err != nil { + if errors.Is(err, ErrRejected) { + txSlots.Resize(uint(i)) + return nil + } + return fmt.Errorf("elem: %w", err) + } + i = i + 1 + return nil + }) + if err != nil { + return err + } + return nil +} diff --git a/types/txn_packets_test.go b/types/txn_packets_test.go index a21c523fd..9d23ac291 100644 --- a/types/txn_packets_test.go +++ b/types/txn_packets_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/require" "github.com/ledgerwatch/erigon-lib/common/hexutility" + "github.com/ledgerwatch/erigon-lib/rlp" ) var hashParseTests = []struct { @@ -205,7 +206,7 @@ func TestTransactionsPacket(t *testing.T) { ctx := NewTxParseContext(*uint256.NewInt(tt.chainID)) slots := &TxSlots{} - _, err := ParseTransactions(encodeBuf, 0, ctx, slots, nil) + err := DecodeTransactions(rlp.NewDecoder(encodeBuf), ctx, slots, nil) require.NoError(err) require.Equal(len(tt.txs), len(slots.Txs)) for i, txn := range 
tt.txs { @@ -223,7 +224,7 @@ func TestTransactionsPacket(t *testing.T) { chainID := uint256.NewInt(tt.chainID) ctx := NewTxParseContext(*chainID) slots := &TxSlots{} - _, err := ParseTransactions(encodeBuf, 0, ctx, slots, func(bytes []byte) error { return ErrRejected }) + err := DecodeTransactions(rlp.NewDecoder(encodeBuf), ctx, slots, func(bytes []byte) error { return ErrRejected }) require.NoError(err) require.Equal(0, len(slots.Txs)) require.Equal(0, slots.Senders.Len()) diff --git a/types/txn_test.go b/types/txn_test.go index b5e992d49..7cc616488 100644 --- a/types/txn_test.go +++ b/types/txn_test.go @@ -17,7 +17,6 @@ package types import ( - "bytes" "crypto/rand" "strconv" "testing" @@ -29,6 +28,7 @@ import ( "github.com/ledgerwatch/erigon-lib/common/fixedgas" "github.com/ledgerwatch/erigon-lib/common/hexutility" + "github.com/ledgerwatch/erigon-lib/rlp" ) func TestParseTransactionRLP(t *testing.T) { @@ -42,26 +42,21 @@ func TestParseTransactionRLP(t *testing.T) { tt := tt t.Run(strconv.Itoa(i), func(t *testing.T) { payload := hexutility.MustDecodeHex(tt.PayloadStr) - parseEnd, err := ctx.ParseTransaction(payload, 0, tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) - require.NoError(err) - require.Equal(len(payload), parseEnd) + dec := rlp.NewDecoder(payload) + err := ctx.DecodeTransaction(dec, tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + require.NoError(err, tt.PayloadStr) + require.Equal(len(payload), dec.Offset()) if tt.SignHashStr != "" { signHash := hexutility.MustDecodeHex(tt.SignHashStr) - if !bytes.Equal(signHash, ctx.Sighash[:]) { - t.Errorf("signHash expected %x, got %x", signHash, ctx.Sighash) - } + assert.EqualValues(t, signHash, ctx.Sighash[:], "sighash") } if tt.IdHashStr != "" { idHash := hexutility.MustDecodeHex(tt.IdHashStr) - if !bytes.Equal(idHash, tx.IDHash[:]) { - t.Errorf("IdHash expected %x, got %x", idHash, tx.IDHash) - } + assert.EqualValues(t, idHash, tx.IDHash[:], "idhash") 
} if tt.SenderStr != "" { expectSender := hexutility.MustDecodeHex(tt.SenderStr) - if !bytes.Equal(expectSender, txSender[:]) { - t.Errorf("expectSender expected %x, got %x", expectSender, txSender) - } + assert.EqualValues(t, expectSender, txSender[:], "sender") } require.Equal(tt.Nonce, tx.Nonce) }) @@ -77,19 +72,19 @@ func TestTransactionSignatureValidity1(t *testing.T) { tx, txSender := &TxSlot{}, [20]byte{} validTxn := hexutility.MustDecodeHex("f83f800182520894095e7baea6a6c7c4c2dfeb977efac326af552d870b801ba048b55bfa915ac795c431978d8a6a992b628d557da5ff759b307d495a3664935301") - _, err := ctx.ParseTransaction(validTxn, 0, tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + err := ctx.DecodeTransaction(rlp.NewDecoder(validTxn), tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) assert.NoError(t, err) preEip2Txn := hexutility.MustDecodeHex("f85f800182520894095e7baea6a6c7c4c2dfeb977efac326af552d870b801ba048b55bfa915ac795c431978d8a6a992b628d557da5ff759b307d495a36649353a07fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a1") - _, err = ctx.ParseTransaction(preEip2Txn, 0, tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + err = ctx.DecodeTransaction(rlp.NewDecoder(preEip2Txn), tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) assert.NoError(t, err) // Now enforce EIP-2 ctx.WithAllowPreEip2s(false) - _, err = ctx.ParseTransaction(validTxn, 0, tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + err = ctx.DecodeTransaction(rlp.NewDecoder(validTxn), tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) assert.NoError(t, err) - _, err = ctx.ParseTransaction(preEip2Txn, 0, tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + err = ctx.DecodeTransaction(rlp.NewDecoder(preEip2Txn), tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) assert.Error(t, err) } @@ -98,13 
+93,13 @@ func TestTransactionSignatureValidity2(t *testing.T) { chainId := new(uint256.Int).SetUint64(5) ctx := NewTxParseContext(*chainId) slot, sender := &TxSlot{}, [20]byte{} - rlp := hexutility.MustDecodeHex("02f8720513844190ab00848321560082520894cab441d2f45a3fee83d15c6b6b6c36a139f55b6288054607fc96a6000080c001a0dffe4cb5651e663d0eac8c4d002de734dd24db0f1109b062d17da290a133cc02a0913fb9f53f7a792bcd9e4d7cced1b8545d1ab82c77432b0bc2e9384ba6c250c5") - _, err := ctx.ParseTransaction(rlp, 0, slot, sender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + xs := hexutility.MustDecodeHex("02f8720513844190ab00848321560082520894cab441d2f45a3fee83d15c6b6b6c36a139f55b6288054607fc96a6000080c001a0dffe4cb5651e663d0eac8c4d002de734dd24db0f1109b062d17da290a133cc02a0913fb9f53f7a792bcd9e4d7cced1b8545d1ab82c77432b0bc2e9384ba6c250c5") + err := ctx.DecodeTransaction(rlp.NewDecoder(xs), slot, sender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) assert.Error(t, err) // Only legacy transactions can happen before EIP-2 ctx.WithAllowPreEip2s(true) - _, err = ctx.ParseTransaction(rlp, 0, slot, sender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + err = ctx.DecodeTransaction(rlp.NewDecoder(xs), slot, sender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) assert.Error(t, err) } @@ -205,9 +200,10 @@ func TestBlobTxParsing(t *testing.T) { require.NoError(t, err) assert.Equal(t, BlobTxType, txType) - p, err := ctx.ParseTransaction(bodyEnvelope, 0, &thinTx, nil, hasEnvelope, wrappedWithBlobs, nil) + dec := rlp.NewDecoder(bodyEnvelope) + err = ctx.DecodeTransaction(dec, &thinTx, nil, hasEnvelope, wrappedWithBlobs, nil) require.NoError(t, err) - assert.Equal(t, len(bodyEnvelope), p) + assert.Equal(t, len(bodyEnvelope), dec.Offset()) assert.Equal(t, len(bodyEnvelope), int(thinTx.Size)) assert.Equal(t, bodyEnvelope[3:], thinTx.Rlp) assert.Equal(t, BlobTxType, thinTx.Type) @@ -258,9 +254,10 @@ func TestBlobTxParsing(t *testing.T) { 
require.NoError(t, err) assert.Equal(t, BlobTxType, txType) - p, err = ctx.ParseTransaction(wrapperRlp, 0, &fatTx, nil, hasEnvelope, wrappedWithBlobs, nil) + dec = rlp.NewDecoder(wrapperRlp) + err = ctx.DecodeTransaction(dec, &fatTx, nil, hasEnvelope, wrappedWithBlobs, nil) require.NoError(t, err) - assert.Equal(t, len(wrapperRlp), p) + assert.Equal(t, len(wrapperRlp), dec.Offset()) assert.Equal(t, len(wrapperRlp), int(fatTx.Size)) assert.Equal(t, wrapperRlp, fatTx.Rlp) assert.Equal(t, BlobTxType, fatTx.Type) @@ -274,7 +271,7 @@ func TestBlobTxParsing(t *testing.T) { assert.Equal(t, thinTx.AlAddrCount, fatTx.AlAddrCount) assert.Equal(t, thinTx.AlStorCount, fatTx.AlStorCount) assert.Equal(t, thinTx.Gas, fatTx.Gas) - assert.Equal(t, thinTx.IDHash, fatTx.IDHash) + assert.Equal(t, thinTx.IDHash, fatTx.IDHash, "idhash") assert.Equal(t, thinTx.Creation, fatTx.Creation) assert.Equal(t, thinTx.BlobFeeCap, fatTx.BlobFeeCap) assert.Equal(t, thinTx.BlobHashes, fatTx.BlobHashes) diff --git a/types/txn_types_fuzz_test.go b/types/txn_types_fuzz_test.go index 1f43a695d..110498c19 100644 --- a/types/txn_types_fuzz_test.go +++ b/types/txn_types_fuzz_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/ledgerwatch/erigon-lib/common/u256" + "github.com/ledgerwatch/erigon-lib/rlp" ) // golang.org/s/draft-fuzzing-design @@ -27,6 +28,6 @@ func FuzzParseTx(f *testing.F) { ctx := NewTxParseContext(*u256.N1) txn := &TxSlot{} sender := make([]byte, 20) - _, _ = ctx.ParseTransaction(in, pos, txn, sender, false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + ctx.DecodeTransaction(rlp.NewDecoder(in), txn, sender, false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) }) }