From e1a641669ff06bf00a96d39874ac374705e40bc8 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 07:13:10 -0500 Subject: [PATCH 01/34] wop --- kv/keydb/kv_remote.go | 732 ++++++++++++++++++++++++++++++++++++++++++ rlp/encoder.go | 191 +++++++++++ rlp/util.go | 34 ++ 3 files changed, 957 insertions(+) create mode 100644 kv/keydb/kv_remote.go create mode 100644 rlp/encoder.go create mode 100644 rlp/util.go diff --git a/kv/keydb/kv_remote.go b/kv/keydb/kv_remote.go new file mode 100644 index 000000000..7c7e5424c --- /dev/null +++ b/kv/keydb/kv_remote.go @@ -0,0 +1,732 @@ +/* + Copyright 2021 Erigon contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package remotedb + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "runtime" + + "github.com/ledgerwatch/erigon-lib/gointerfaces" + "github.com/ledgerwatch/erigon-lib/gointerfaces/grpcutil" + "github.com/ledgerwatch/erigon-lib/gointerfaces/remote" + "github.com/ledgerwatch/erigon-lib/kv" + "github.com/ledgerwatch/erigon-lib/kv/iter" + "github.com/ledgerwatch/erigon-lib/kv/order" + "github.com/ledgerwatch/log/v3" + "golang.org/x/sync/semaphore" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/ledgerwatch/erigon-lib/kv/mdbx" + + "github.com/redis/go-redis/v9" +) + +// generate the messages and services +type remoteOpts struct { + remoteKV remote.KVClient + log log.Logger + bucketsCfg mdbx.TableCfgFunc + DialAddress string + version gointerfaces.Version +} + +var _ kv.TemporalTx = (*tx)(nil) + +type DB struct { + remoteKV remote.KVClient + log log.Logger + buckets kv.TableCfg + roTxsLimiter *semaphore.Weighted + opts remoteOpts + remote *redis.ClusterClient +} + +type tx struct { + stream remote.KV_TxClient + ctx context.Context + streamCancelFn context.CancelFunc + db *DB + statelessCursors map[string]kv.Cursor + cursors []*remoteCursor + streams []kv.Closer + viewID, id uint64 + streamingRequested bool +} + +type remoteCursor struct { + ctx context.Context + stream remote.KV_TxClient + tx *tx + bucketName string + bucketCfg kv.TableCfgItem + id uint32 +} + +type remoteCursorDupSort struct { + *remoteCursor +} + +func (opts remoteOpts) ReadOnly() remoteOpts { + return opts +} + +func (opts remoteOpts) WithBucketsConfig(f mdbx.TableCfgFunc) remoteOpts { + opts.bucketsCfg = f + return opts +} + +func (opts remoteOpts) Open() (*DB, error) { + targetSemCount := int64(runtime.GOMAXPROCS(-1)) - 1 + if targetSemCount <= 1 { + targetSemCount = 2 + } + + db := &DB{ + opts: opts, + remoteKV: opts.remoteKV, + log: log.New("remote_db", opts.DialAddress), + buckets: kv.TableCfg{}, + roTxsLimiter: semaphore.NewWeighted(targetSemCount), // 1 less than max to allow unlocking + } + customBuckets := opts.bucketsCfg(kv.ChaindataTablesCfg) + for name, cfg := range customBuckets { // copy map to avoid changing global variable + db.buckets[name] = cfg + } + + return db, nil +} + +func (opts remoteOpts) MustOpen() kv.RwDB { + db, err := opts.Open() + if err != nil { + panic(err) + } + return db +} + +// NewRemote defines new remove KV connection (without actually opening it) +// version parameters represent the version the KV client is expecting, +// compatibility check will be performed when the KV connection opens +func NewRemote(v gointerfaces.Version, logger log.Logger, remoteKV remote.KVClient) remoteOpts { + return remoteOpts{bucketsCfg: mdbx.WithChaindataTables, version: v, log: logger, remoteKV: remoteKV} +} + +func (db *DB) PageSize() uint64 { panic("not implemented") } +func (db *DB) ReadOnly() bool { return true } +func (db *DB) AllTables() kv.TableCfg { return db.buckets } + +func (db *DB) EnsureVersionCompatibility() bool { + versionReply, err := db.remoteKV.Version(context.Background(), &emptypb.Empty{}, grpc.WaitForReady(true)) + if err != nil { + db.log.Error("getting Version", "error", err) + return false + } + if !gointerfaces.EnsureVersion(db.opts.version, versionReply) { + db.log.Error("incompatible interface versions", "client", db.opts.version.String(), + "server", fmt.Sprintf("%d.%d.%d", versionReply.Major, versionReply.Minor, versionReply.Patch)) + return false + } + db.log.Info("interfaces compatible", "client", db.opts.version.String(), + "server", fmt.Sprintf("%d.%d.%d", versionReply.Major, versionReply.Minor, versionReply.Patch)) + return true +} + +func (db *DB) Close() {} + +func (db *DB) BeginRo(ctx context.Context) (txn kv.Tx, err error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + if semErr := db.roTxsLimiter.Acquire(ctx, 1); semErr != nil { + return nil, semErr + } + + defer func() { + // ensure we release the semaphore on error + if txn == nil { + db.roTxsLimiter.Release(1) + } + }() + + streamCtx, streamCancelFn := context.WithCancel(ctx) // We create child context for the stream so we can cancel it to prevent leak + stream, err := db.remoteKV.Tx(streamCtx) + if err != nil { + streamCancelFn() + return nil, err + } + msg, err := stream.Recv() + if err != nil { + streamCancelFn() + return nil, err + } + return &tx{ctx: ctx, db: db, stream: stream, streamCancelFn: streamCancelFn, viewID: msg.ViewId, id: msg.TxId}, nil +} +func (db *DB) BeginTemporalRo(ctx context.Context) (kv.TemporalTx, error) { + t, err := db.BeginRo(ctx) + if err != nil { + return nil, err + } + return t.(kv.TemporalTx), nil +} +func (db *DB) BeginRw(ctx context.Context) (kv.RwTx, error) { + return nil, fmt.Errorf("remote db provider doesn't support .BeginRw method") +} +func (db *DB) BeginRwNosync(ctx context.Context) (kv.RwTx, error) { + return nil, fmt.Errorf("remote db provider doesn't support .BeginRw method") +} +func (db *DB) BeginTemporalRw(ctx context.Context) (kv.RwTx, error) { + return nil, fmt.Errorf("remote db provider doesn't support .BeginTemporalRw method") +} +func (db *DB) BeginTemporalRwNosync(ctx context.Context) (kv.RwTx, error) { + return nil, fmt.Errorf("remote db provider doesn't support .BeginTemporalRwNosync method") +} + +func (db *DB) View(ctx context.Context, f func(tx kv.Tx) error) (err error) { + tx, err := db.BeginRo(ctx) + if err != nil { + return err + } + defer tx.Rollback() + return f(tx) +} +func (db *DB) ViewTemporal(ctx context.Context, f func(tx kv.TemporalTx) error) (err error) { + tx, err := db.BeginTemporalRo(ctx) + if err != nil { + return err + } + defer tx.Rollback() + return f(tx) +} + +func (db *DB) Update(ctx context.Context, f func(tx kv.RwTx) error) (err error) { + return fmt.Errorf("remote db provider doesn't support .Update method") +} +func (db *DB) UpdateNosync(ctx context.Context, f func(tx kv.RwTx) error) (err error) { + return fmt.Errorf("remote db provider doesn't support .UpdateNosync method") +} + +func (tx *tx) ViewID() uint64 { return tx.viewID } +func (tx *tx) CollectMetrics() {} +func (tx *tx) IncrementSequence(bucket string, amount uint64) (uint64, error) { + panic("not implemented yet") +} +func (tx *tx) ReadSequence(bucket string) (uint64, error) { + panic("not implemented yet") +} +func (tx *tx) Append(bucket string, k, v []byte) error { panic("no write methods") } +func (tx *tx) AppendDup(bucket string, k, v []byte) error { panic("no write methods") } + +func (tx *tx) Commit() error { + panic("remote db is read-only") +} + +func (tx *tx) Rollback() { + // don't close opened cursors - just close stream, server will cleanup everything well + tx.closeGrpcStream() + tx.db.roTxsLimiter.Release(1) + for _, c := range tx.streams { + c.Close() + } +} +func (tx *tx) DBSize() (uint64, error) { panic("not implemented") } + +func (tx *tx) statelessCursor(bucket string) (kv.Cursor, error) { + if tx.statelessCursors == nil { + tx.statelessCursors = make(map[string]kv.Cursor) + } + c, ok := tx.statelessCursors[bucket] + if !ok { + var err error + c, err = tx.Cursor(bucket) + if err != nil { + return nil, err + } + tx.statelessCursors[bucket] = c + } + return c, nil +} + +func (tx *tx) BucketSize(name string) (uint64, error) { panic("not implemented") } + +func (tx *tx) ForEach(bucket string, fromPrefix []byte, walker func(k, v []byte) error) error { + it, err := tx.Range(bucket, fromPrefix, nil) + if err != nil { + return err + } + for it.HasNext() { + k, v, err := it.Next() + if err != nil { + return err + } + if err := walker(k, v); err != nil { + return err + } + } + return nil +} + +func (tx *tx) ForPrefix(bucket string, prefix []byte, walker func(k, v []byte) error) error { + it, err := tx.Prefix(bucket, prefix) + if err != nil { + return err + } + for it.HasNext() { + k, v, err := it.Next() + if err != nil { + return err + } + if err := walker(k, v); err != nil { + return err + } + } + return nil +} + +// TODO: this must be deprecated +func (tx *tx) ForAmount(bucket string, fromPrefix []byte, amount uint32, walker func(k, v []byte) error) error { + if amount == 0 { + return nil + } + c, err := tx.Cursor(bucket) + if err != nil { + return err + } + defer c.Close() + + for k, v, err := c.Seek(fromPrefix); k != nil && amount > 0; k, v, err = c.Next() { + if err != nil { + return err + } + if err := walker(k, v); err != nil { + return err + } + amount-- + } + return nil +} + +func (tx *tx) GetOne(bucket string, k []byte) (val []byte, err error) { + c, err := tx.statelessCursor(bucket) + if err != nil { + return nil, err + } + _, val, err = c.SeekExact(k) + return val, err +} + +func (tx *tx) Has(bucket string, k []byte) (bool, error) { + c, err := tx.statelessCursor(bucket) + if err != nil { + return false, err + } + kk, _, err := c.Seek(k) + if err != nil { + return false, err + } + return bytes.Equal(k, kk), nil +} + +func (c *remoteCursor) SeekExact(k []byte) (key, val []byte, err error) { + return c.seekExact(k) +} + +func (c *remoteCursor) Prev() ([]byte, []byte, error) { + return c.prev() +} + +func (tx *tx) Cursor(bucket string) (kv.Cursor, error) { + b := tx.db.buckets[bucket] + c := &remoteCursor{tx: tx, ctx: tx.ctx, bucketName: bucket, bucketCfg: b, stream: tx.stream} + tx.cursors = append(tx.cursors, c) + if err := c.stream.Send(&remote.Cursor{Op: remote.Op_OPEN, BucketName: c.bucketName}); err != nil { + return nil, err + } + msg, err := c.stream.Recv() + if err != nil { + return nil, err + } + c.id = msg.CursorId + return c, nil +} + +func (tx *tx) ListBuckets() ([]string, error) { + return nil, fmt.Errorf("function ListBuckets is not implemented for remoteTx") +} + +// func (c *remoteCursor) Put(k []byte, v []byte) error { panic("not supported") } +// func (c *remoteCursor) PutNoOverwrite(k []byte, v []byte) error { panic("not supported") } +// func (c *remoteCursor) Append(k []byte, v []byte) error { panic("not supported") } +// func (c *remoteCursor) Delete(k []byte) error { panic("not supported") } +// func (c *remoteCursor) DeleteCurrent() error { panic("not supported") } +func (c *remoteCursor) Count() (uint64, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_COUNT}); err != nil { + return 0, err + } + pair, err := c.stream.Recv() + if err != nil { + return 0, err + } + return binary.BigEndian.Uint64(pair.V), nil + +} + +func (c *remoteCursor) first() ([]byte, []byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_FIRST}); err != nil { + return []byte{}, nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return []byte{}, nil, err + } + return pair.K, pair.V, nil +} + +func (c *remoteCursor) next() ([]byte, []byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_NEXT}); err != nil { + return []byte{}, nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return []byte{}, nil, err + } + return pair.K, pair.V, nil +} +func (c *remoteCursor) nextDup() ([]byte, []byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_NEXT_DUP}); err != nil { + return []byte{}, nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return []byte{}, nil, err + } + return pair.K, pair.V, nil +} +func (c *remoteCursor) nextNoDup() ([]byte, []byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_NEXT_NO_DUP}); err != nil { + return []byte{}, nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return []byte{}, nil, err + } + return pair.K, pair.V, nil +} +func (c *remoteCursor) prev() ([]byte, []byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_PREV}); err != nil { + return []byte{}, nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return []byte{}, nil, err + } + return pair.K, pair.V, nil +} +func (c *remoteCursor) prevDup() ([]byte, []byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_PREV_DUP}); err != nil { + return []byte{}, nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return []byte{}, nil, err + } + return pair.K, pair.V, nil +} +func (c *remoteCursor) prevNoDup() ([]byte, []byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_PREV_NO_DUP}); err != nil { + return []byte{}, nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return []byte{}, nil, err + } + return pair.K, pair.V, nil +} +func (c *remoteCursor) last() ([]byte, []byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_LAST}); err != nil { + return []byte{}, nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return []byte{}, nil, err + } + return pair.K, pair.V, nil +} +func (c *remoteCursor) setRange(k []byte) ([]byte, []byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_SEEK, K: k}); err != nil { + return []byte{}, nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return []byte{}, nil, err + } + return pair.K, pair.V, nil +} +func (c *remoteCursor) seekExact(k []byte) ([]byte, []byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_SEEK_EXACT, K: k}); err != nil { + return []byte{}, nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return []byte{}, nil, err + } + return pair.K, pair.V, nil +} +func (c *remoteCursor) getBothRange(k, v []byte) ([]byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_SEEK_BOTH, K: k, V: v}); err != nil { + return nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return nil, err + } + return pair.V, nil +} +func (c *remoteCursor) seekBothExact(k, v []byte) ([]byte, []byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_SEEK_BOTH_EXACT, K: k, V: v}); err != nil { + return []byte{}, nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return []byte{}, nil, err + } + return pair.K, pair.V, nil +} +func (c *remoteCursor) firstDup() ([]byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_FIRST_DUP}); err != nil { + return nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return nil, err + } + return pair.V, nil +} +func (c *remoteCursor) lastDup() ([]byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_LAST_DUP}); err != nil { + return nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return nil, err + } + return pair.V, nil +} +func (c *remoteCursor) getCurrent() ([]byte, []byte, error) { + if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_CURRENT}); err != nil { + return []byte{}, nil, err + } + pair, err := c.stream.Recv() + if err != nil { + return []byte{}, nil, err + } + return pair.K, pair.V, nil +} + +func (c *remoteCursor) Current() ([]byte, []byte, error) { + return c.getCurrent() +} + +// Seek - doesn't start streaming (because much of code does only several .Seek calls without reading sequence of data) +// .Next() - does request streaming (if configured by user) +func (c *remoteCursor) Seek(seek []byte) ([]byte, []byte, error) { + return c.setRange(seek) +} + +func (c *remoteCursor) First() ([]byte, []byte, error) { + return c.first() +} + +// Next - returns next data element from server, request streaming (if configured by user) +func (c *remoteCursor) Next() ([]byte, []byte, error) { + return c.next() +} + +func (c *remoteCursor) Last() ([]byte, []byte, error) { + return c.last() +} + +func (tx *tx) closeGrpcStream() { + if tx.stream == nil { + return + } + defer tx.streamCancelFn() // hard cancel stream if graceful wasn't successful + + if tx.streamingRequested { + // if streaming is in progress, can't use `CloseSend` - because + // server will not read it right not - it busy with streaming data + // TODO: set flag 'tx.streamingRequested' to false when got terminator from server (nil key or os.EOF) + tx.streamCancelFn() + } else { + // try graceful close stream + err := tx.stream.CloseSend() + if err != nil { + doLog := !grpcutil.IsEndOfStream(err) + if doLog { + log.Warn("couldn't send msg CloseSend to server", "err", err) + } + } else { + _, err = tx.stream.Recv() + if err != nil { + doLog := !grpcutil.IsEndOfStream(err) + if doLog { + log.Warn("received unexpected error from server after CloseSend", "err", err) + } + } + } + } + tx.stream = nil + tx.streamingRequested = false +} + +func (c *remoteCursor) Close() { + if c.stream == nil { + return + } + st := c.stream + c.stream = nil + if err := st.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_CLOSE}); err == nil { + _, _ = st.Recv() + } +} + +func (tx *tx) CursorDupSort(bucket string) (kv.CursorDupSort, error) { + b := tx.db.buckets[bucket] + c := &remoteCursor{tx: tx, ctx: tx.ctx, bucketName: bucket, bucketCfg: b, stream: tx.stream} + tx.cursors = append(tx.cursors, c) + if err := c.stream.Send(&remote.Cursor{Op: remote.Op_OPEN_DUP_SORT, BucketName: c.bucketName}); err != nil { + return nil, err + } + msg, err := c.stream.Recv() + if err != nil { + return nil, err + } + c.id = msg.CursorId + return &remoteCursorDupSort{remoteCursor: c}, nil +} + +func (c *remoteCursorDupSort) SeekBothExact(k, v []byte) ([]byte, []byte, error) { + return c.seekBothExact(k, v) +} + +func (c *remoteCursorDupSort) SeekBothRange(k, v []byte) ([]byte, error) { + return c.getBothRange(k, v) +} + +func (c *remoteCursorDupSort) DeleteExact(k1, k2 []byte) error { panic("not supported") } +func (c *remoteCursorDupSort) AppendDup(k []byte, v []byte) error { panic("not supported") } +func (c *remoteCursorDupSort) PutNoDupData(k, v []byte) error { panic("not supported") } +func (c *remoteCursorDupSort) DeleteCurrentDuplicates() error { panic("not supported") } +func (c *remoteCursorDupSort) CountDuplicates() (uint64, error) { panic("not supported") } + +func (c *remoteCursorDupSort) FirstDup() ([]byte, error) { return c.firstDup() } +func (c *remoteCursorDupSort) NextDup() ([]byte, []byte, error) { return c.nextDup() } +func (c *remoteCursorDupSort) NextNoDup() ([]byte, []byte, error) { return c.nextNoDup() } +func (c *remoteCursorDupSort) PrevDup() ([]byte, []byte, error) { return c.prevDup() } +func (c *remoteCursorDupSort) PrevNoDup() ([]byte, []byte, error) { return c.prevNoDup() } +func (c *remoteCursorDupSort) LastDup() ([]byte, error) { return c.lastDup() } + +// Temporal Methods +func (tx *tx) DomainGetAsOf(name kv.Domain, k, k2 []byte, ts uint64) (v []byte, ok bool, err error) { + reply, err := tx.db.remoteKV.DomainGet(tx.ctx, &remote.DomainGetReq{TxId: tx.id, Table: string(name), K: k, K2: k2, Ts: ts}) + if err != nil { + return nil, false, err + } + return reply.V, reply.Ok, nil +} + +func (tx *tx) DomainGet(name kv.Domain, k, k2 []byte) (v []byte, ok bool, err error) { + reply, err := tx.db.remoteKV.DomainGet(tx.ctx, &remote.DomainGetReq{TxId: tx.id, Table: string(name), K: k, K2: k2, Latest: true}) + if err != nil { + return nil, false, err + } + return reply.V, reply.Ok, nil +} + +func (tx *tx) DomainRange(name kv.Domain, fromKey, toKey []byte, ts uint64, asc order.By, limit int) (it iter.KV, err error) { + return iter.PaginateKV(func(pageToken string) (keys, vals [][]byte, nextPageToken string, err error) { + reply, err := tx.db.remoteKV.DomainRange(tx.ctx, &remote.DomainRangeReq{TxId: tx.id, Table: string(name), FromKey: fromKey, ToKey: toKey, Ts: ts, OrderAscend: bool(asc), Limit: int64(limit)}) + if err != nil { + return nil, nil, "", err + } + return reply.Keys, reply.Values, reply.NextPageToken, nil + }), nil +} +func (tx *tx) HistoryGet(name kv.History, k []byte, ts uint64) (v []byte, ok bool, err error) { + reply, err := tx.db.remoteKV.HistoryGet(tx.ctx, &remote.HistoryGetReq{TxId: tx.id, Table: string(name), K: k, Ts: ts}) + if err != nil { + return nil, false, err + } + return reply.V, reply.Ok, nil +} +func (tx *tx) HistoryRange(name kv.History, fromTs, toTs int, asc order.By, limit int) (it iter.KV, err error) { + return iter.PaginateKV(func(pageToken string) (keys, vals [][]byte, nextPageToken string, err error) { + reply, err := tx.db.remoteKV.HistoryRange(tx.ctx, &remote.HistoryRangeReq{TxId: tx.id, Table: string(name), FromTs: int64(fromTs), ToTs: int64(toTs), OrderAscend: bool(asc), Limit: int64(limit)}) + if err != nil { + return nil, nil, "", err + } + return reply.Keys, reply.Values, reply.NextPageToken, nil + }), nil +} + +func (tx *tx) IndexRange(name kv.InvertedIdx, k []byte, fromTs, toTs int, asc order.By, limit int) (timestamps iter.U64, err error) { + return iter.PaginateU64(func(pageToken string) (arr []uint64, nextPageToken string, err error) { + req := &remote.IndexRangeReq{TxId: tx.id, Table: string(name), K: k, FromTs: int64(fromTs), ToTs: int64(toTs), OrderAscend: bool(asc), Limit: int64(limit)} + reply, err := tx.db.remoteKV.IndexRange(tx.ctx, req) + if err != nil { + return nil, "", err + } + return reply.Timestamps, reply.NextPageToken, nil + }), nil +} + +func (tx *tx) Prefix(table string, prefix []byte) (iter.KV, error) { + nextPrefix, ok := kv.NextSubtree(prefix) + if !ok { + return tx.Range(table, prefix, nil) + } + return tx.Range(table, prefix, nextPrefix) +} + +func (tx *tx) rangeOrderLimit(table string, fromPrefix, toPrefix []byte, asc order.By, limit int) (iter.KV, error) { + return iter.PaginateKV(func(pageToken string) (keys [][]byte, values [][]byte, nextPageToken string, err error) { + req := &remote.RangeReq{TxId: tx.id, Table: table, FromPrefix: fromPrefix, ToPrefix: toPrefix, OrderAscend: bool(asc), Limit: int64(limit)} + reply, err := tx.db.remoteKV.Range(tx.ctx, req) + if err != nil { + return nil, nil, "", err + } + return reply.Keys, reply.Values, reply.NextPageToken, nil + }), nil +} +func (tx *tx) Range(table string, fromPrefix, toPrefix []byte) (iter.KV, error) { + return tx.rangeOrderLimit(table, fromPrefix, toPrefix, order.Asc, -1) +} +func (tx *tx) RangeAscend(table string, fromPrefix, toPrefix []byte, limit int) (iter.KV, error) { + return tx.rangeOrderLimit(table, fromPrefix, toPrefix, order.Asc, limit) +} +func (tx *tx) RangeDescend(table string, fromPrefix, toPrefix []byte, limit int) (iter.KV, error) { + return tx.rangeOrderLimit(table, fromPrefix, toPrefix, order.Desc, limit) +} +func (tx *tx) RangeDupSort(table string, key []byte, fromPrefix, toPrefix []byte, asc order.By, limit int) (iter.KV, error) { + panic("not implemented yet") +} diff --git a/rlp/encoder.go b/rlp/encoder.go new file mode 100644 index 000000000..3eb56b269 --- /dev/null +++ b/rlp/encoder.go @@ -0,0 +1,191 @@ +package rlp + +import "golang.org/x/exp/constraints" + +type Encoder struct { + buf []byte +} + +func NewEncoder(buf []byte) *Encoder { + return &Encoder{ + buf: buf, + } +} + +func (e *Encoder) Bytes() []byte { + return e.buf +} +func (e *Encoder) Reset(b []byte) { + e.buf = b +} + +func (e *Encoder) WriteByte(p byte) (err error) { + e.buf = append(e.buf, p) + return nil +} + +func (e *Encoder) Write(p []byte) (n int, err error) { + e.buf = append(e.buf, p...) + return len(p), nil +} + +func (e *Encoder) WriteString(str []byte) { + if len(str) > 55 { + e.WriteLongString(str) + } + e.WriteShortString(str) +} + +func (e *Encoder) WriteShortString(str []byte) { + e.WriteByte(TokenShortString.Plus(byte(len(str)))) + e.Write(str) + return +} + +func (e *Encoder) WriteLongString(str []byte) { + // write the indicator token + e.WriteByte(byte(TokenLongString)) + // write the integer, knowing that we appended n bytes + n := putUint(e, len(str)) + // so we knw the indicator token was n+1 bytes ago. + e.buf[len(e.buf)-(int(n)+1)] += n + // and now add the actual length + e.buf = append(e.buf, str...) + return +} + +func (e *Encoder) WriteList(items ...func(i *Encoder)) { + e.writeList(true, items...) +} + +// WriteShortList will assume that your list payload is more than 55 bytes long +func (e *Encoder) WriteShortList(items ...func(i *Encoder)) { + e.buf = append(e.buf, TokenShortList.Plus(byte(len(items)))) + for _, v := range items { + v(e) + } + return +} + +// WriteLongList will assume that your list payload is more than 55 bytes long, and do no validation as such +func (e *Encoder) WriteLongList(items ...func(i *Encoder)) { + e.writeList(false, items...) +} + +// writeList will first attempt to write a long list with the dat +// if validate is false, it will just format it like the length is above 55 +// if validate is true, it will format it like it is a shrot list +func (e *Encoder) writeList(validate bool, items ...func(i *Encoder)) { + // write the indicator token + e.buf = append(e.buf, byte(TokenLongList)) + // now pad 8 bytes + e.buf = append(e.buf, make([]byte, 8)...) + // record the length before encoding items + startLength := len(e.buf) + // now write all the items + for _, v := range items { + v(e) + } + // the size is the difference in the lengths now + dataSize := len(e.buf) - startLength + if dataSize <= 55 && validate { + // oh it's actually a short string! awkward. let's set that then. + e.buf[startLength-9] = TokenShortList.Plus(byte(len(items))) + // and then copy the data over + copy(e.buf[startLength-8:], e.buf[startLength:startLength+dataSize]) + // and now set the new size + e.buf = e.buf[:startLength+dataSize-8] + // we are done, return + return + } + // ok, so it's a long string. + // create a new encoder centered at startLength - 8 + enc := NewEncoder(e.buf[startLength-8:]) + // now write using that encoder the size + n := putUint(enc, dataSize) + // and update the token, which we know is at startLength-9 + e.buf[startLength-9] += n + // the shift to perform now is 8 - n. + shift := int(8 - n) + // if there is a positive shift, then we must perform the shift + if shift > 0 { + // copy the data + copy(e.buf[startLength-shift:], e.buf[startLength:startLength+dataSize]) + // set the new length + e.buf = e.buf[:startLength-shift+dataSize] + } + return +} + +func putUint[T constraints.Integer](e *Encoder, t T) (size byte) { + i := uint64(t) + switch { + case i < (1 << 8): + e.buf = append(e.buf, byte(i)) + return 1 + case i < (1 << 16): + e.buf = append(e.buf, + byte(i>>8), + byte(i), + ) + return 2 + case i < (1 << 24): + + e.buf = append(e.buf, + byte(i>>16), + byte(i>>8), + byte(i), + ) + return 3 + case i < (1 << 32): + e.buf = append(e.buf, + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + return 4 + case i < (1 << 40): + e.buf = append(e.buf, + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + return 5 + case i < (1 << 48): + e.buf = append(e.buf, + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + return 6 + case i < (1 << 56): + e.buf = append(e.buf, + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + return 7 + default: + e.buf = append(e.buf, + byte(i>>56), + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + return 8 + } +} diff --git a/rlp/util.go b/rlp/util.go new file mode 100644 index 000000000..784c297d5 --- /dev/null +++ b/rlp/util.go @@ -0,0 +1,34 @@ +package rlp + +type Token byte + +func (T Token) Plus(n byte) byte { + return byte(T) + n +} + +// This token table can also be used for offsets. how cool! +const ( + TokenDecimal Token = 0x00 + TokenShortString Token = 0x80 + TokenLongString Token = 0xb7 + TokenShortList Token = 0xc0 + TokenLongList Token = 0xf7 + + TokenUnknown Token = 0xff +) + +func identifyToken(b byte) Token { + switch { + case b >= 0 && b <= 127: + return TokenDecimal + case b >= 128 && b <= 183: + return TokenShortString + case b >= 184 && b <= 191: + return TokenLongString + case b >= 192 && b <= 247: + return TokenShortList + case b >= 248 && b <= 255: + return TokenLongList + } + return TokenUnknown +} From b1eddb1134b42882bf59209f3cec761d55ed6fb7 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 07:13:34 -0500 Subject: [PATCH 02/34] oops --- kv/keydb/kv_remote.go | 732 ------------------------------------------ 1 file changed, 732 deletions(-) delete mode 100644 kv/keydb/kv_remote.go diff --git a/kv/keydb/kv_remote.go b/kv/keydb/kv_remote.go deleted file mode 100644 index 7c7e5424c..000000000 --- a/kv/keydb/kv_remote.go +++ /dev/null @@ -1,732 +0,0 @@ -/* - Copyright 2021 Erigon contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -package remotedb - -import ( - "bytes" - "context" - "encoding/binary" - "fmt" - "runtime" - - "github.com/ledgerwatch/erigon-lib/gointerfaces" - "github.com/ledgerwatch/erigon-lib/gointerfaces/grpcutil" - "github.com/ledgerwatch/erigon-lib/gointerfaces/remote" - "github.com/ledgerwatch/erigon-lib/kv" - "github.com/ledgerwatch/erigon-lib/kv/iter" - "github.com/ledgerwatch/erigon-lib/kv/order" - "github.com/ledgerwatch/log/v3" - "golang.org/x/sync/semaphore" - "google.golang.org/grpc" - "google.golang.org/protobuf/types/known/emptypb" - - "github.com/ledgerwatch/erigon-lib/kv/mdbx" - - "github.com/redis/go-redis/v9" -) - -// generate the messages and services -type remoteOpts struct { - remoteKV remote.KVClient - log log.Logger - bucketsCfg mdbx.TableCfgFunc - DialAddress string - version gointerfaces.Version -} - -var _ kv.TemporalTx = (*tx)(nil) - -type DB struct { - remoteKV remote.KVClient - log log.Logger - buckets kv.TableCfg - roTxsLimiter *semaphore.Weighted - opts remoteOpts - remote *redis.ClusterClient -} - -type tx struct { - stream remote.KV_TxClient - ctx context.Context - streamCancelFn context.CancelFunc - db *DB - statelessCursors map[string]kv.Cursor - cursors []*remoteCursor - streams []kv.Closer - viewID, id uint64 - streamingRequested bool -} - -type remoteCursor struct { - ctx context.Context - stream remote.KV_TxClient - tx *tx - bucketName string - bucketCfg kv.TableCfgItem - id uint32 -} - -type remoteCursorDupSort struct { - *remoteCursor -} - -func (opts remoteOpts) ReadOnly() remoteOpts { - return opts -} - -func (opts remoteOpts) WithBucketsConfig(f mdbx.TableCfgFunc) remoteOpts { - opts.bucketsCfg = f - return opts -} - -func (opts remoteOpts) Open() (*DB, error) { - targetSemCount := int64(runtime.GOMAXPROCS(-1)) - 1 - if targetSemCount <= 1 { - targetSemCount = 2 - } - - db := &DB{ - opts: opts, - remoteKV: opts.remoteKV, - log: log.New("remote_db", opts.DialAddress), - buckets: kv.TableCfg{}, - roTxsLimiter: semaphore.NewWeighted(targetSemCount), // 1 less than max to allow unlocking - } - customBuckets := opts.bucketsCfg(kv.ChaindataTablesCfg) - for name, cfg := range customBuckets { // copy map to avoid changing global variable - db.buckets[name] = cfg - } - - return db, nil -} - -func (opts remoteOpts) MustOpen() kv.RwDB { - db, err := opts.Open() - if err != nil { - panic(err) - } - return db -} - -// NewRemote defines new remove KV connection (without actually opening it) -// version parameters represent the version the KV client is expecting, -// compatibility check will be performed when the KV connection opens -func NewRemote(v gointerfaces.Version, logger log.Logger, remoteKV remote.KVClient) remoteOpts { - return remoteOpts{bucketsCfg: mdbx.WithChaindataTables, version: v, log: logger, remoteKV: remoteKV} -} - -func (db *DB) PageSize() uint64 { panic("not implemented") } -func (db *DB) ReadOnly() bool { return true } -func (db *DB) AllTables() kv.TableCfg { return db.buckets } - -func (db *DB) EnsureVersionCompatibility() bool { - versionReply, err := db.remoteKV.Version(context.Background(), &emptypb.Empty{}, grpc.WaitForReady(true)) - if err != nil { - db.log.Error("getting Version", "error", err) - return false - } - if !gointerfaces.EnsureVersion(db.opts.version, versionReply) { - db.log.Error("incompatible interface versions", "client", db.opts.version.String(), - "server", fmt.Sprintf("%d.%d.%d", versionReply.Major, versionReply.Minor, versionReply.Patch)) - return false - } - db.log.Info("interfaces compatible", "client", db.opts.version.String(), - "server", fmt.Sprintf("%d.%d.%d", versionReply.Major, versionReply.Minor, versionReply.Patch)) - return true -} - -func (db *DB) Close() {} - -func (db *DB) BeginRo(ctx context.Context) (txn kv.Tx, err error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - if semErr := db.roTxsLimiter.Acquire(ctx, 1); semErr != nil { - return nil, semErr - } - - defer func() { - // ensure we release the semaphore on error - if txn == nil { - db.roTxsLimiter.Release(1) - } - }() - - streamCtx, streamCancelFn := context.WithCancel(ctx) // We create child context for the stream so we can cancel it to prevent leak - stream, err := db.remoteKV.Tx(streamCtx) - if err != nil { - streamCancelFn() - return nil, err - } - msg, err := stream.Recv() - if err != nil { - streamCancelFn() - return nil, err - } - return &tx{ctx: ctx, db: db, stream: stream, streamCancelFn: streamCancelFn, viewID: msg.ViewId, id: msg.TxId}, nil -} -func (db *DB) BeginTemporalRo(ctx context.Context) (kv.TemporalTx, error) { - t, err := db.BeginRo(ctx) - if err != nil { - return nil, err - } - return t.(kv.TemporalTx), nil -} -func (db *DB) BeginRw(ctx context.Context) (kv.RwTx, error) { - return nil, fmt.Errorf("remote db provider doesn't support .BeginRw method") -} -func (db *DB) BeginRwNosync(ctx context.Context) (kv.RwTx, error) { - return nil, fmt.Errorf("remote db provider doesn't support .BeginRw method") -} -func (db *DB) BeginTemporalRw(ctx context.Context) (kv.RwTx, error) { - return nil, fmt.Errorf("remote db provider doesn't support .BeginTemporalRw method") -} -func (db *DB) BeginTemporalRwNosync(ctx context.Context) (kv.RwTx, error) { - return nil, fmt.Errorf("remote db provider doesn't support .BeginTemporalRwNosync method") -} - -func (db *DB) View(ctx context.Context, f func(tx kv.Tx) error) (err error) { - tx, err := db.BeginRo(ctx) - if err != nil { - return err - } - defer tx.Rollback() - return f(tx) -} -func (db *DB) ViewTemporal(ctx context.Context, f func(tx kv.TemporalTx) error) (err error) { - tx, err := db.BeginTemporalRo(ctx) - if err != nil { - return err - } - defer tx.Rollback() - return f(tx) -} - -func (db *DB) Update(ctx context.Context, f func(tx kv.RwTx) error) (err error) { - return fmt.Errorf("remote db provider doesn't support .Update method") -} -func (db *DB) UpdateNosync(ctx context.Context, f func(tx kv.RwTx) error) (err error) { - return fmt.Errorf("remote db provider doesn't support .UpdateNosync method") -} - -func (tx *tx) ViewID() uint64 { return tx.viewID } -func (tx *tx) CollectMetrics() {} -func (tx *tx) IncrementSequence(bucket string, amount uint64) (uint64, error) { - panic("not implemented yet") -} -func (tx *tx) ReadSequence(bucket string) (uint64, error) { - panic("not implemented yet") -} -func (tx *tx) Append(bucket string, k, v []byte) error { panic("no write methods") } -func (tx *tx) AppendDup(bucket string, k, v []byte) error { panic("no write methods") } - -func (tx *tx) Commit() error { - panic("remote db is read-only") -} - -func (tx *tx) Rollback() { - // don't close opened cursors - just close stream, server will cleanup everything well - tx.closeGrpcStream() - tx.db.roTxsLimiter.Release(1) - for _, c := range tx.streams { - c.Close() - } -} -func (tx *tx) DBSize() (uint64, error) { panic("not implemented") } - -func (tx *tx) statelessCursor(bucket string) (kv.Cursor, error) { - if tx.statelessCursors == nil { - tx.statelessCursors = make(map[string]kv.Cursor) - } - c, ok := tx.statelessCursors[bucket] - if !ok { - var err error - c, err = tx.Cursor(bucket) - if err != nil { - return nil, err - } - tx.statelessCursors[bucket] = c - } - return c, nil -} - -func (tx *tx) BucketSize(name string) (uint64, error) { panic("not implemented") } - -func (tx *tx) ForEach(bucket string, fromPrefix []byte, walker func(k, v []byte) error) error { - it, err := tx.Range(bucket, fromPrefix, nil) - if err != nil { - return err - } - for it.HasNext() { - k, v, err := it.Next() - if err != nil { - return err - } - if err := walker(k, v); err != nil { - return err - } - } - return nil -} - -func (tx *tx) ForPrefix(bucket string, prefix []byte, walker func(k, v []byte) error) error { - it, err := tx.Prefix(bucket, prefix) - if err != nil { - return err - } - for it.HasNext() { - k, v, err := it.Next() - if err != nil { - return err - } - if err := walker(k, v); err != nil { - return err - } - } - return nil -} - -// TODO: this must be deprecated -func (tx *tx) ForAmount(bucket string, fromPrefix []byte, amount uint32, walker func(k, v []byte) error) error { - if amount == 0 { - return nil - } - c, err := tx.Cursor(bucket) - if err != nil { - return err - } - defer c.Close() - - for k, v, err := c.Seek(fromPrefix); k != nil && amount > 0; k, v, err = c.Next() { - if err != nil { - return err - } - if err := walker(k, v); err != nil { - return err - } - amount-- - } - return nil -} - -func (tx *tx) GetOne(bucket string, k []byte) (val []byte, err error) { - c, err := tx.statelessCursor(bucket) - if err != nil { - return nil, err - } - _, val, err = c.SeekExact(k) - return val, err -} - -func (tx *tx) Has(bucket string, k []byte) (bool, error) { - c, err := tx.statelessCursor(bucket) - if err != nil { - return false, err - } - kk, _, err := c.Seek(k) - if err != nil { - return false, err - } - return bytes.Equal(k, kk), nil -} - -func (c *remoteCursor) SeekExact(k []byte) (key, val []byte, err error) { - return c.seekExact(k) -} - -func (c *remoteCursor) Prev() ([]byte, []byte, error) { - return c.prev() -} - -func (tx *tx) Cursor(bucket string) (kv.Cursor, error) { - b := tx.db.buckets[bucket] - c := &remoteCursor{tx: tx, ctx: tx.ctx, bucketName: bucket, bucketCfg: b, stream: tx.stream} - tx.cursors = append(tx.cursors, c) - if err := c.stream.Send(&remote.Cursor{Op: remote.Op_OPEN, BucketName: c.bucketName}); err != nil { - return nil, err - } - msg, err := c.stream.Recv() - if err != nil { - return nil, err - } - c.id = msg.CursorId - return c, nil -} - -func (tx *tx) ListBuckets() ([]string, error) { - return nil, fmt.Errorf("function ListBuckets is not implemented for remoteTx") -} - -// func (c *remoteCursor) Put(k []byte, v []byte) error { panic("not supported") } -// func (c *remoteCursor) PutNoOverwrite(k []byte, v []byte) error { panic("not supported") } -// func (c *remoteCursor) Append(k []byte, v []byte) error { panic("not supported") } -// func (c *remoteCursor) Delete(k []byte) error { panic("not supported") } -// func (c *remoteCursor) DeleteCurrent() error { panic("not supported") } -func (c *remoteCursor) Count() (uint64, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_COUNT}); err != nil { - return 0, err - } - pair, err := c.stream.Recv() - if err != nil { - return 0, err - } - return binary.BigEndian.Uint64(pair.V), nil - -} - -func (c *remoteCursor) first() ([]byte, []byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_FIRST}); err != nil { - return []byte{}, nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return []byte{}, nil, err - } - return pair.K, pair.V, nil -} - -func (c *remoteCursor) next() ([]byte, []byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_NEXT}); err != nil { - return []byte{}, nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return []byte{}, nil, err - } - return pair.K, pair.V, nil -} -func (c *remoteCursor) nextDup() ([]byte, []byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_NEXT_DUP}); err != nil { - return []byte{}, nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return []byte{}, nil, err - } - return pair.K, pair.V, nil -} -func (c *remoteCursor) nextNoDup() ([]byte, []byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_NEXT_NO_DUP}); err != nil { - return []byte{}, nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return []byte{}, nil, err - } - return pair.K, pair.V, nil -} -func (c *remoteCursor) prev() ([]byte, []byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_PREV}); err != nil { - return []byte{}, nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return []byte{}, nil, err - } - return pair.K, pair.V, nil -} -func (c *remoteCursor) prevDup() ([]byte, []byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_PREV_DUP}); err != nil { - return []byte{}, nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return []byte{}, nil, err - } - return pair.K, pair.V, nil -} -func (c *remoteCursor) prevNoDup() ([]byte, []byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_PREV_NO_DUP}); err != nil { - return []byte{}, nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return []byte{}, nil, err - } - return pair.K, pair.V, nil -} -func (c *remoteCursor) last() ([]byte, []byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_LAST}); err != nil { - return []byte{}, nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return []byte{}, nil, err - } - return pair.K, pair.V, nil -} -func (c *remoteCursor) setRange(k []byte) ([]byte, []byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_SEEK, K: k}); err != nil { - return []byte{}, nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return []byte{}, nil, err - } - return pair.K, pair.V, nil -} -func (c *remoteCursor) seekExact(k []byte) ([]byte, []byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_SEEK_EXACT, K: k}); err != nil { - return []byte{}, nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return []byte{}, nil, err - } - return pair.K, pair.V, nil -} -func (c *remoteCursor) getBothRange(k, v []byte) ([]byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_SEEK_BOTH, K: k, V: v}); err != nil { - return nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return nil, err - } - return pair.V, nil -} -func (c *remoteCursor) seekBothExact(k, v []byte) ([]byte, []byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_SEEK_BOTH_EXACT, K: k, V: v}); err != nil { - return []byte{}, nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return []byte{}, nil, err - } - return pair.K, pair.V, nil -} -func (c *remoteCursor) firstDup() ([]byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_FIRST_DUP}); err != nil { - return nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return nil, err - } - return pair.V, nil -} -func (c *remoteCursor) lastDup() ([]byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_LAST_DUP}); err != nil { - return nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return nil, err - } - return pair.V, nil -} -func (c *remoteCursor) getCurrent() ([]byte, []byte, error) { - if err := c.stream.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_CURRENT}); err != nil { - return []byte{}, nil, err - } - pair, err := c.stream.Recv() - if err != nil { - return []byte{}, nil, err - } - return pair.K, pair.V, nil -} - -func (c *remoteCursor) Current() ([]byte, []byte, error) { - return c.getCurrent() -} - -// Seek - doesn't start streaming (because much of code does only several .Seek calls without reading sequence of data) -// .Next() - does request streaming (if configured by user) -func (c *remoteCursor) Seek(seek []byte) ([]byte, []byte, error) { - return c.setRange(seek) -} - -func (c *remoteCursor) First() ([]byte, []byte, error) { - return c.first() -} - -// Next - returns next data element from server, request streaming (if configured by user) -func (c *remoteCursor) Next() ([]byte, []byte, error) { - return c.next() -} - -func (c *remoteCursor) Last() ([]byte, []byte, error) { - return c.last() -} - -func (tx *tx) closeGrpcStream() { - if tx.stream == nil { - return - } - defer tx.streamCancelFn() // hard cancel stream if graceful wasn't successful - - if tx.streamingRequested { - // if streaming is in progress, can't use `CloseSend` - because - // server will not read it right not - it busy with streaming data - // TODO: set flag 'tx.streamingRequested' to false when got terminator from server (nil key or os.EOF) - tx.streamCancelFn() - } else { - // try graceful close stream - err := tx.stream.CloseSend() - if err != nil { - doLog := !grpcutil.IsEndOfStream(err) - if doLog { - log.Warn("couldn't send msg CloseSend to server", "err", err) - } - } else { - _, err = tx.stream.Recv() - if err != nil { - doLog := !grpcutil.IsEndOfStream(err) - if doLog { - log.Warn("received unexpected error from server after CloseSend", "err", err) - } - } - } - } - tx.stream = nil - tx.streamingRequested = false -} - -func (c *remoteCursor) Close() { - if c.stream == nil { - return - } - st := c.stream - c.stream = nil - if err := st.Send(&remote.Cursor{Cursor: c.id, Op: remote.Op_CLOSE}); err == nil { - _, _ = st.Recv() - } -} - -func (tx *tx) CursorDupSort(bucket string) (kv.CursorDupSort, error) { - b := tx.db.buckets[bucket] - c := &remoteCursor{tx: tx, ctx: tx.ctx, bucketName: bucket, bucketCfg: b, stream: tx.stream} - tx.cursors = append(tx.cursors, c) - if err := c.stream.Send(&remote.Cursor{Op: remote.Op_OPEN_DUP_SORT, BucketName: c.bucketName}); err != nil { - return nil, err - } - msg, err := c.stream.Recv() - if err != nil { - return nil, err - } - c.id = msg.CursorId - return &remoteCursorDupSort{remoteCursor: c}, nil -} - -func (c *remoteCursorDupSort) SeekBothExact(k, v []byte) ([]byte, []byte, error) { - return c.seekBothExact(k, v) -} - -func (c *remoteCursorDupSort) SeekBothRange(k, v []byte) ([]byte, error) { - return c.getBothRange(k, v) -} - -func (c *remoteCursorDupSort) DeleteExact(k1, k2 []byte) error { panic("not supported") } -func (c *remoteCursorDupSort) AppendDup(k []byte, v []byte) error { panic("not supported") } -func (c *remoteCursorDupSort) PutNoDupData(k, v []byte) error { panic("not supported") } -func (c *remoteCursorDupSort) DeleteCurrentDuplicates() error { panic("not supported") } -func (c *remoteCursorDupSort) CountDuplicates() (uint64, error) { panic("not supported") } - -func (c *remoteCursorDupSort) FirstDup() ([]byte, error) { return c.firstDup() } -func (c *remoteCursorDupSort) NextDup() ([]byte, []byte, error) { return c.nextDup() } -func (c *remoteCursorDupSort) NextNoDup() ([]byte, []byte, error) { return c.nextNoDup() } -func (c *remoteCursorDupSort) PrevDup() ([]byte, []byte, error) { return c.prevDup() } -func (c *remoteCursorDupSort) PrevNoDup() ([]byte, []byte, error) { return c.prevNoDup() } -func (c *remoteCursorDupSort) LastDup() ([]byte, error) { return c.lastDup() } - -// Temporal Methods -func (tx *tx) DomainGetAsOf(name kv.Domain, k, k2 []byte, ts uint64) (v []byte, ok bool, err error) { - reply, err := tx.db.remoteKV.DomainGet(tx.ctx, &remote.DomainGetReq{TxId: tx.id, Table: string(name), K: k, K2: k2, Ts: ts}) - if err != nil { - return nil, false, err - } - return reply.V, reply.Ok, nil -} - -func (tx *tx) DomainGet(name kv.Domain, k, k2 []byte) (v []byte, ok bool, err error) { - reply, err := tx.db.remoteKV.DomainGet(tx.ctx, &remote.DomainGetReq{TxId: tx.id, Table: string(name), K: k, K2: k2, Latest: true}) - if err != nil { - return nil, false, err - } - return reply.V, reply.Ok, nil -} - -func (tx *tx) DomainRange(name kv.Domain, fromKey, toKey []byte, ts uint64, asc order.By, limit int) (it iter.KV, err error) { - return iter.PaginateKV(func(pageToken string) (keys, vals [][]byte, nextPageToken string, err error) { - reply, err := tx.db.remoteKV.DomainRange(tx.ctx, &remote.DomainRangeReq{TxId: tx.id, Table: string(name), FromKey: fromKey, ToKey: toKey, Ts: ts, OrderAscend: bool(asc), Limit: int64(limit)}) - if err != nil { - return nil, nil, "", err - } - return reply.Keys, reply.Values, reply.NextPageToken, nil - }), nil -} -func (tx *tx) HistoryGet(name kv.History, k []byte, ts uint64) (v []byte, ok bool, err error) { - reply, err := tx.db.remoteKV.HistoryGet(tx.ctx, &remote.HistoryGetReq{TxId: tx.id, Table: string(name), K: k, Ts: ts}) - if err != nil { - return nil, false, err - } - return reply.V, reply.Ok, nil -} -func (tx *tx) HistoryRange(name kv.History, fromTs, toTs int, asc order.By, limit int) (it iter.KV, err error) { - return iter.PaginateKV(func(pageToken string) (keys, vals [][]byte, nextPageToken string, err error) { - reply, err := tx.db.remoteKV.HistoryRange(tx.ctx, &remote.HistoryRangeReq{TxId: tx.id, Table: string(name), FromTs: int64(fromTs), ToTs: int64(toTs), OrderAscend: bool(asc), Limit: int64(limit)}) - if err != nil { - return nil, nil, "", err - } - return reply.Keys, reply.Values, reply.NextPageToken, nil - }), nil -} - -func (tx *tx) IndexRange(name kv.InvertedIdx, k []byte, fromTs, toTs int, asc order.By, limit int) (timestamps iter.U64, err error) { - return iter.PaginateU64(func(pageToken string) (arr []uint64, nextPageToken string, err error) { - req := &remote.IndexRangeReq{TxId: tx.id, Table: string(name), K: k, FromTs: int64(fromTs), ToTs: int64(toTs), OrderAscend: bool(asc), Limit: int64(limit)} - reply, err := tx.db.remoteKV.IndexRange(tx.ctx, req) - if err != nil { - return nil, "", err - } - return reply.Timestamps, reply.NextPageToken, nil - }), nil -} - -func (tx *tx) Prefix(table string, prefix []byte) (iter.KV, error) { - nextPrefix, ok := kv.NextSubtree(prefix) - if !ok { - return tx.Range(table, prefix, nil) - } - return tx.Range(table, prefix, nextPrefix) -} - -func (tx *tx) rangeOrderLimit(table string, fromPrefix, toPrefix []byte, asc order.By, limit int) (iter.KV, error) { - return iter.PaginateKV(func(pageToken string) (keys [][]byte, values [][]byte, nextPageToken string, err error) { - req := &remote.RangeReq{TxId: tx.id, Table: table, FromPrefix: fromPrefix, ToPrefix: toPrefix, OrderAscend: bool(asc), Limit: int64(limit)} - reply, err := tx.db.remoteKV.Range(tx.ctx, req) - if err != nil { - return nil, nil, "", err - } - return reply.Keys, reply.Values, reply.NextPageToken, nil - }), nil -} -func (tx *tx) Range(table string, fromPrefix, toPrefix []byte) (iter.KV, error) { - return tx.rangeOrderLimit(table, fromPrefix, toPrefix, order.Asc, -1) -} -func (tx *tx) RangeAscend(table string, fromPrefix, toPrefix []byte, limit int) (iter.KV, error) { - return tx.rangeOrderLimit(table, fromPrefix, toPrefix, order.Asc, limit) -} -func (tx *tx) RangeDescend(table string, fromPrefix, toPrefix []byte, limit int) (iter.KV, error) { - return tx.rangeOrderLimit(table, fromPrefix, toPrefix, order.Desc, limit) -} -func (tx *tx) RangeDupSort(table string, key []byte, fromPrefix, toPrefix []byte, asc order.By, limit int) (iter.KV, error) { - panic("not implemented yet") -} From 4d606c38a77a2b0802f75fde1d7401456bcd6781 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 07:22:18 -0500 Subject: [PATCH 03/34] change interface --- rlp/encoder.go | 69 ++++++++++++++++++++++++++++---------------------- 1 file changed, 39 insertions(+), 30 deletions(-) diff --git a/rlp/encoder.go b/rlp/encoder.go index 3eb56b269..db938feab 100644 --- a/rlp/encoder.go +++ b/rlp/encoder.go @@ -2,6 +2,8 @@ package rlp import "golang.org/x/exp/constraints" +type EncoderFunc = func(i *Encoder) *Encoder + type Encoder struct { buf []byte } @@ -12,79 +14,86 @@ func NewEncoder(buf []byte) *Encoder { } } -func (e *Encoder) Bytes() []byte { +func (e *Encoder) Buffer() []byte { return e.buf } + func (e *Encoder) Reset(b []byte) { e.buf = b } -func (e *Encoder) WriteByte(p byte) (err error) { +func (e *Encoder) Write(p []byte) (n int, err error) { + e.Bytes(p) + return len(p), nil +} + +func (e *Encoder) Byte(p byte) *Encoder { e.buf = append(e.buf, p) - return nil + return e } -func (e *Encoder) Write(p []byte) (n int, err error) { +func (e *Encoder) Bytes(p []byte) *Encoder { e.buf = append(e.buf, p...) - return len(p), nil + return e } -func (e *Encoder) WriteString(str []byte) { +// Str will write a string correctly +func (e *Encoder) Str(str []byte) *Encoder { if len(str) > 55 { - e.WriteLongString(str) + return e.LongString(str) } - e.WriteShortString(str) + return e.ShortString(str) } -func (e *Encoder) WriteShortString(str []byte) { - e.WriteByte(TokenShortString.Plus(byte(len(str)))) - e.Write(str) - return +// String will assume your string is less than 56 bytes long, and do no validation as such +func (e *Encoder) ShortString(str []byte) *Encoder { + return e.Byte(TokenShortString.Plus(byte(len(str)))).Bytes(str) } -func (e *Encoder) WriteLongString(str []byte) { +// String will assume your string is greater than 55 bytes long, and do no validation as such +func (e *Encoder) LongString(str []byte) *Encoder { // write the indicator token - e.WriteByte(byte(TokenLongString)) + e.Byte(byte(TokenLongString)) // write the integer, knowing that we appended n bytes n := putUint(e, len(str)) // so we knw the indicator token was n+1 bytes ago. e.buf[len(e.buf)-(int(n)+1)] += n // and now add the actual length e.buf = append(e.buf, str...) - return + return e } -func (e *Encoder) WriteList(items ...func(i *Encoder)) { - e.writeList(true, items...) +func (e *Encoder) List(items ...EncoderFunc) *Encoder { + return e.writeList(true, items...) } -// WriteShortList will assume that your list payload is more than 55 bytes long -func (e *Encoder) WriteShortList(items ...func(i *Encoder)) { +// ShortList will assume that your list payload is less than 56 bytes long, and do no validation as such +func (e *Encoder) ShortList(items ...EncoderFunc) *Encoder { e.buf = append(e.buf, TokenShortList.Plus(byte(len(items)))) for _, v := range items { - v(e) + e = v(e) } - return + return e } -// WriteLongList will assume that your list payload is more than 55 bytes long, and do no validation as such -func (e *Encoder) WriteLongList(items ...func(i *Encoder)) { - e.writeList(false, items...) +// LongList will assume that your list payload is more than 55 bytes long, and do no validation as such +func (e *Encoder) LongList(items ...EncoderFunc) *Encoder { + return e.writeList(false, items...) } // writeList will first attempt to write a long list with the dat // if validate is false, it will just format it like the length is above 55 // if validate is true, it will format it like it is a shrot list -func (e *Encoder) writeList(validate bool, items ...func(i *Encoder)) { +func (e *Encoder) writeList(validate bool, items ...EncoderFunc) *Encoder { // write the indicator token - e.buf = append(e.buf, byte(TokenLongList)) + e = e.Byte(byte(TokenLongList)) // now pad 8 bytes - e.buf = append(e.buf, make([]byte, 8)...) + e = e.Bytes(make([]byte, 8)) // record the length before encoding items startLength := len(e.buf) // now write all the items for _, v := range items { - v(e) + e = v(e) } // the size is the difference in the lengths now dataSize := len(e.buf) - startLength @@ -96,7 +105,7 @@ func (e *Encoder) writeList(validate bool, items ...func(i *Encoder)) { // and now set the new size e.buf = e.buf[:startLength+dataSize-8] // we are done, return - return + return e } // ok, so it's a long string. // create a new encoder centered at startLength - 8 @@ -114,7 +123,7 @@ func (e *Encoder) writeList(validate bool, items ...func(i *Encoder)) { // set the new length e.buf = e.buf[:startLength-shift+dataSize] } - return + return e } func putUint[T constraints.Integer](e *Encoder, t T) (size byte) { From 0f33aeab90db38681883fcd747a8452ecc195651 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 08:36:42 -0500 Subject: [PATCH 04/34] ok --- rlp/decoder.go | 89 ++++++++++++++++++++++++++++++++++++++++++++++++++ rlp/encoder.go | 16 +++------ rlp/parse.go | 9 ++--- rlp/util.go | 5 ++- 4 files changed, 102 insertions(+), 17 deletions(-) create mode 100644 rlp/decoder.go diff --git a/rlp/decoder.go b/rlp/decoder.go new file mode 100644 index 000000000..21cb61875 --- /dev/null +++ b/rlp/decoder.go @@ -0,0 +1,89 @@ +package rlp + +import ( + "fmt" + "reflect" +) + +func unmarshal(data []byte, val any) error { + rv := reflect.ValueOf(val) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return fmt.Errorf("%w: v must be ptr", ErrDecode) + } + + v := rv.Elem() + + // read the first byte + if len(data) == 0 { + return ErrUnexpectedEOF + } + + // figure out what we are reading + token := identifyToken(data[0]) + + // switch + switch token { + case TokenDecimal: + // in this case, the value is just the byte itself + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.SetInt(int64(data[0])) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + v.SetUint(uint64(data[0])) + default: + return fmt.Errorf("%w: decimal must be unmarshal into integer type", ErrDecode) + } + case TokenShortString: + sz := int(token.Diff(data[0])) + if len(data) <= 1+sz { + return ErrUnexpectedEOF + } + dat := data[1 : 1+sz] + switch v.Kind() { + case reflect.String: + v.SetString(string(dat)) + case reflect.Slice: + if v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("%w: need to use uint8 as underlying if want slice output from shortstring", ErrDecode) + } + v.SetBytes(dat) + case reflect.Array: + if v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("%w: need to use uint8 as underlying if want array output from shortstring", ErrDecode) + } + reflect.Copy(v, reflect.ValueOf(dat)) + } + case TokenLongString: + lenSz := int(token.Diff(data[0])) + if len(data) <= 1+lenSz { + return ErrUnexpectedEOF + } + sz, err := BeInt(data, 1, lenSz) + if err != nil { + return err + } + if len(data) <= 1+sz { + return ErrUnexpectedEOF + } + dat := data[1+lenSz : 1+sz+lenSz] + switch v.Kind() { + case reflect.String: + v.SetString(string(dat)) + case reflect.Slice: + if v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("%w: need to use uint8 as underlying if want slice output from longstring", ErrDecode) + } + v.SetBytes(dat) + case reflect.Array: + if v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("%w: need to use uint8 as underlying if want array output from longstring", ErrDecode) + } + reflect.Copy(v, reflect.ValueOf(dat)) + } + case TokenShortList: + case TokenLongList: + case TokenUnknown: + return fmt.Errorf("%w: unknown token", ErrDecode) + } + return nil +} diff --git a/rlp/encoder.go b/rlp/encoder.go index db938feab..e042ae5c7 100644 --- a/rlp/encoder.go +++ b/rlp/encoder.go @@ -14,19 +14,11 @@ func NewEncoder(buf []byte) *Encoder { } } +// Buffer returns the underlying buffer func (e *Encoder) Buffer() []byte { return e.buf } -func (e *Encoder) Reset(b []byte) { - e.buf = b -} - -func (e *Encoder) Write(p []byte) (n int, err error) { - e.Bytes(p) - return len(p), nil -} - func (e *Encoder) Byte(p byte) *Encoder { e.buf = append(e.buf, p) return e @@ -99,7 +91,7 @@ func (e *Encoder) writeList(validate bool, items ...EncoderFunc) *Encoder { dataSize := len(e.buf) - startLength if dataSize <= 55 && validate { // oh it's actually a short string! awkward. let's set that then. - e.buf[startLength-9] = TokenShortList.Plus(byte(len(items))) + e.buf[startLength-8-1] = TokenShortList.Plus(byte(len(items))) // and then copy the data over copy(e.buf[startLength-8:], e.buf[startLength:startLength+dataSize]) // and now set the new size @@ -112,8 +104,8 @@ func (e *Encoder) writeList(validate bool, items ...EncoderFunc) *Encoder { enc := NewEncoder(e.buf[startLength-8:]) // now write using that encoder the size n := putUint(enc, dataSize) - // and update the token, which we know is at startLength-9 - e.buf[startLength-9] += n + // and update the token, which we know is at startLength-8-1 + e.buf[startLength-8-1] += n // the shift to perform now is 8 - n. shift := int(8 - n) // if there is a positive shift, then we must perform the shift diff --git a/rlp/parse.go b/rlp/parse.go index cbe59749a..23ffc4637 100644 --- a/rlp/parse.go +++ b/rlp/parse.go @@ -26,9 +26,10 @@ import ( ) var ( - ErrBase = fmt.Errorf("rlp") - ErrParse = fmt.Errorf("%w parse", ErrBase) - ErrDecode = fmt.Errorf("%w decode", ErrBase) + ErrBase = fmt.Errorf("rlp") + ErrParse = fmt.Errorf("%w parse", ErrBase) + ErrDecode = fmt.Errorf("%w decode", ErrBase) + ErrUnexpectedEOF = fmt.Errorf("%w EOF", ErrBase) ) func IsRLPError(err error) bool { return errors.Is(err, ErrBase) } @@ -37,7 +38,7 @@ func IsRLPError(err error) bool { return errors.Is(err, ErrBase) } func BeInt(payload []byte, pos, length int) (int, error) { var r int if pos+length >= len(payload) { - return 0, fmt.Errorf("%w: unexpected end of payload", ErrParse) + return 0, ErrUnexpectedEOF } if length > 0 && payload[pos] == 0 { return 0, fmt.Errorf("%w: integer encoding for RLP must not have leading zeros: %x", ErrParse, payload[pos:pos+length]) diff --git a/rlp/util.go b/rlp/util.go index 784c297d5..0ba45efa0 100644 --- a/rlp/util.go +++ b/rlp/util.go @@ -6,7 +6,10 @@ func (T Token) Plus(n byte) byte { return byte(T) + n } -// This token table can also be used for offsets. how cool! +func (T Token) Diff(n byte) byte { + return n - byte(T) +} + const ( TokenDecimal Token = 0x00 TokenShortString Token = 0x80 From 50d979e035dd5a4974f0d62536d4f5852c7720a7 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 08:38:00 -0500 Subject: [PATCH 05/34] ok --- rlp/encoder.go | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/rlp/encoder.go b/rlp/encoder.go index e042ae5c7..972286e4b 100644 --- a/rlp/encoder.go +++ b/rlp/encoder.go @@ -55,17 +55,14 @@ func (e *Encoder) LongString(str []byte) *Encoder { return e } +// List will attempt to write the list of encoder funcs to the buf func (e *Encoder) List(items ...EncoderFunc) *Encoder { return e.writeList(true, items...) } -// ShortList will assume that your list payload is less than 56 bytes long, and do no validation as such +// ShortList actually calls List func (e *Encoder) ShortList(items ...EncoderFunc) *Encoder { - e.buf = append(e.buf, TokenShortList.Plus(byte(len(items)))) - for _, v := range items { - e = v(e) - } - return e + return e.writeList(true, items...) } // LongList will assume that your list payload is more than 55 bytes long, and do no validation as such @@ -91,7 +88,7 @@ func (e *Encoder) writeList(validate bool, items ...EncoderFunc) *Encoder { dataSize := len(e.buf) - startLength if dataSize <= 55 && validate { // oh it's actually a short string! awkward. let's set that then. - e.buf[startLength-8-1] = TokenShortList.Plus(byte(len(items))) + e.buf[startLength-8-1] = TokenShortList.Plus(byte(dataSize)) // and then copy the data over copy(e.buf[startLength-8:], e.buf[startLength:startLength+dataSize]) // and now set the new size From 0709dfc4486cfa0fb94efb448b59a60374c1cbdf Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 08:49:24 -0500 Subject: [PATCH 06/34] wip --- rlp/decoder.go | 93 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 65 insertions(+), 28 deletions(-) diff --git a/rlp/decoder.go b/rlp/decoder.go index 21cb61875..8cb8d7c49 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -39,20 +39,7 @@ func unmarshal(data []byte, val any) error { return ErrUnexpectedEOF } dat := data[1 : 1+sz] - switch v.Kind() { - case reflect.String: - v.SetString(string(dat)) - case reflect.Slice: - if v.Type().Elem().Kind() != reflect.Uint8 { - return fmt.Errorf("%w: need to use uint8 as underlying if want slice output from shortstring", ErrDecode) - } - v.SetBytes(dat) - case reflect.Array: - if v.Type().Elem().Kind() != reflect.Uint8 { - return fmt.Errorf("%w: need to use uint8 as underlying if want array output from shortstring", ErrDecode) - } - reflect.Copy(v, reflect.ValueOf(dat)) - } + return reflectString(dat, v, rv) case TokenLongString: lenSz := int(token.Diff(data[0])) if len(data) <= 1+lenSz { @@ -66,24 +53,74 @@ func unmarshal(data []byte, val any) error { return ErrUnexpectedEOF } dat := data[1+lenSz : 1+sz+lenSz] - switch v.Kind() { - case reflect.String: - v.SetString(string(dat)) - case reflect.Slice: - if v.Type().Elem().Kind() != reflect.Uint8 { - return fmt.Errorf("%w: need to use uint8 as underlying if want slice output from longstring", ErrDecode) - } - v.SetBytes(dat) - case reflect.Array: - if v.Type().Elem().Kind() != reflect.Uint8 { - return fmt.Errorf("%w: need to use uint8 as underlying if want array output from longstring", ErrDecode) - } - reflect.Copy(v, reflect.ValueOf(dat)) - } + return reflectString(dat, v, rv) case TokenShortList: + sz := int(token.Diff(data[0])) + if len(data) <= 1+sz { + return ErrUnexpectedEOF + } + dat := data[1 : 1+sz] + return reflectList(dat, v, rv) case TokenLongList: + lenSz := int(token.Diff(data[0])) + if len(data) <= 1+lenSz { + return ErrUnexpectedEOF + } + sz, err := BeInt(data, 1, lenSz) + if err != nil { + return err + } + if len(data) <= 1+sz { + return ErrUnexpectedEOF + } + dat := data[1+lenSz : 1+sz+lenSz] + return reflectList(dat, v, rv) case TokenUnknown: return fmt.Errorf("%w: unknown token", ErrDecode) } return nil } + +func reflectString(dat []byte, v reflect.Value, rv reflect.Value) error { + switch v.Kind() { + case reflect.String: + v.SetString(string(dat)) + case reflect.Slice: + if v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("%w: need to use uint8 as underlying if want slice output from longstring", ErrDecode) + } + v.SetBytes(dat) + case reflect.Array: + if v.Type().Elem().Kind() != reflect.Uint8 { + return fmt.Errorf("%w: need to use uint8 as underlying if want array output from longstring", ErrDecode) + } + reflect.Copy(v, reflect.ValueOf(dat)) + } + + return nil +} + +func reflectList(dat []byte, v reflect.Value, rv reflect.Value) error { + switch v.Kind() { + case reflect.Map: + // TODO: read two elements. + rv1 := reflect.New(v.Type().Key()) + v1 := rv1.Elem() + err := reflectString(dat, v1, rv1) + if err != nil { + return err + } + //TODO: need to advance dat cursor - create helper class + rv2 := reflect.New(v.Type().Elem()) + v2 := rv1.Elem() + err = reflectString(dat, v2, rv2) + if err != nil { + return err + } + case reflect.Array: + // TODO: read up to N elements + case reflect.Slice: + // TODO: read all elements into slice, creating more if needed + } + return nil +} From 4cdb4d360eadb0014decefc0dfbfbb6b41b38e9f Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 20:04:02 -0500 Subject: [PATCH 07/34] more --- rlp/decoder.go | 161 ++++++++++++++++++++++++++++++-------------- rlp/decoder_test.go | 7 ++ 2 files changed, 116 insertions(+), 52 deletions(-) create mode 100644 rlp/decoder_test.go diff --git a/rlp/decoder.go b/rlp/decoder.go index 8cb8d7c49..772e4fa0a 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -1,126 +1,183 @@ package rlp import ( + "bytes" "fmt" "reflect" ) -func unmarshal(data []byte, val any) error { +func Unmarshal(data []byte, val any) error { + buf := bytes.NewBuffer(data) + return unmarshal(buf, val) + +} +func unmarshal(buf *bytes.Buffer, val any) error { rv := reflect.ValueOf(val) if rv.Kind() != reflect.Pointer || rv.IsNil() { return fmt.Errorf("%w: v must be ptr", ErrDecode) } - v := rv.Elem() + err := reflectAny(buf, v, rv) + if err != nil { + return fmt.Errorf("%w: %w", ErrDecode, err) + } + return nil +} + +func nextFull(dat *bytes.Buffer, size int) ([]byte, error) { + d := dat.Next(size) + if len(d) != size { + return nil, ErrUnexpectedEOF + } + return d, nil +} - // read the first byte - if len(data) == 0 { - return ErrUnexpectedEOF +// BeInt parses Big Endian representation of an integer from given payload at given position +func decodeBeInt(w *bytes.Buffer, length int) (int, error) { + var r int + dat, err := nextFull(w, length) + if err != nil { + return 0, ErrUnexpectedEOF + } + if length > 0 && dat[0] == 0 { + return 0, fmt.Errorf("%w: integer encoding for RLP must not have leading zeros: %x", ErrParse, dat) + } + for _, b := range dat[0:length] { + r = (r << 8) | int(b) } + return r, nil +} +func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { // figure out what we are reading - token := identifyToken(data[0]) - + prefix, err := w.ReadByte() + if err != nil { + return err + } + token := identifyToken(prefix) // switch switch token { case TokenDecimal: // in this case, the value is just the byte itself switch v.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - v.SetInt(int64(data[0])) + v.SetInt(int64(prefix)) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - v.SetUint(uint64(data[0])) + v.SetUint(uint64(prefix)) + case reflect.Invalid: + // do nothing default: return fmt.Errorf("%w: decimal must be unmarshal into integer type", ErrDecode) } case TokenShortString: - sz := int(token.Diff(data[0])) - if len(data) <= 1+sz { - return ErrUnexpectedEOF + sz := int(token.Diff(prefix)) + str, err := nextFull(w, sz) + if err != nil { + return err } - dat := data[1 : 1+sz] - return reflectString(dat, v, rv) + return putString(str, v, rv) case TokenLongString: - lenSz := int(token.Diff(data[0])) - if len(data) <= 1+lenSz { - return ErrUnexpectedEOF - } - sz, err := BeInt(data, 1, lenSz) + lenSz := int(token.Diff(prefix)) + sz, err := decodeBeInt(w, lenSz) if err != nil { return err } - if len(data) <= 1+sz { - return ErrUnexpectedEOF + str, err := nextFull(w, sz) + if err != nil { + return err } - dat := data[1+lenSz : 1+sz+lenSz] - return reflectString(dat, v, rv) + return putString(str, v, rv) case TokenShortList: - sz := int(token.Diff(data[0])) - if len(data) <= 1+sz { - return ErrUnexpectedEOF + sz := int(token.Diff(prefix)) + buf, err := nextFull(w, sz) + if err != nil { + return err } - dat := data[1 : 1+sz] - return reflectList(dat, v, rv) + return reflectList(bytes.NewBuffer(buf), v, rv) case TokenLongList: - lenSz := int(token.Diff(data[0])) - if len(data) <= 1+lenSz { - return ErrUnexpectedEOF - } - sz, err := BeInt(data, 1, lenSz) + lenSz := int(token.Diff(prefix)) + sz, err := decodeBeInt(w, lenSz) if err != nil { return err } - if len(data) <= 1+sz { - return ErrUnexpectedEOF + buf, err := nextFull(w, sz) + if err != nil { + return err } - dat := data[1+lenSz : 1+sz+lenSz] - return reflectList(dat, v, rv) + return reflectList(bytes.NewBuffer(buf), v, rv) case TokenUnknown: return fmt.Errorf("%w: unknown token", ErrDecode) } return nil } -func reflectString(dat []byte, v reflect.Value, rv reflect.Value) error { +func putString(w []byte, v reflect.Value, rv reflect.Value) error { switch v.Kind() { case reflect.String: - v.SetString(string(dat)) + v.SetString(string(w)) case reflect.Slice: if v.Type().Elem().Kind() != reflect.Uint8 { return fmt.Errorf("%w: need to use uint8 as underlying if want slice output from longstring", ErrDecode) } - v.SetBytes(dat) + v.SetBytes(w) case reflect.Array: if v.Type().Elem().Kind() != reflect.Uint8 { return fmt.Errorf("%w: need to use uint8 as underlying if want array output from longstring", ErrDecode) } - reflect.Copy(v, reflect.ValueOf(dat)) + reflect.Copy(v, reflect.ValueOf(w)) + case reflect.Invalid: + // do nothing + return nil } - return nil } -func reflectList(dat []byte, v reflect.Value, rv reflect.Value) error { +func reflectList(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { switch v.Kind() { + case reflect.Invalid: + // do nothing + return nil case reflect.Map: - // TODO: read two elements. rv1 := reflect.New(v.Type().Key()) v1 := rv1.Elem() - err := reflectString(dat, v1, rv1) + err := reflectAny(w, v1, rv1) if err != nil { return err } - //TODO: need to advance dat cursor - create helper class rv2 := reflect.New(v.Type().Elem()) - v2 := rv1.Elem() - err = reflectString(dat, v2, rv2) + v2 := rv2.Elem() + err = reflectAny(w, v2, rv2) if err != nil { return err } - case reflect.Array: - // TODO: read up to N elements - case reflect.Slice: - // TODO: read all elements into slice, creating more if needed + v.SetMapIndex(rv1, rv2) + case reflect.Array, reflect.Slice: + idx := 0 + for { + if idx >= v.Cap() { + v.Grow(1) + } + if idx >= v.Len() { + v.SetLen(idx + 1) + } + if idx < v.Len() { + // Decode into element. + rv1 := v.Index(idx) + v1 := rv1.Elem() + err := reflectAny(w, v1, rv1) + if err != nil { + return err + } + } else { + // Ran out of fixed array: skip. + rv1 := reflect.Value{} + err := reflectAny(w, rv1, rv1) + if err != nil { + return err + } + } + idx++ + } } return nil } diff --git a/rlp/decoder_test.go b/rlp/decoder_test.go new file mode 100644 index 000000000..02cb825d3 --- /dev/null +++ b/rlp/decoder_test.go @@ -0,0 +1,7 @@ +package rlp_test + +import "testing" + +func TestHelloWorld(t *testing.T) { + // t.Fatal("not implemented") +} From 3f577ee3e21baff0c957c7dc96eb3fbafd8c6f87 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 20:15:50 -0500 Subject: [PATCH 08/34] wip --- rlp/decoder.go | 12 ++++++++++++ rlp/decoder_test.go | 33 ++++++++++++++++++++++++++++++--- rlp/parse.go | 2 +- 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/rlp/decoder.go b/rlp/decoder.go index 772e4fa0a..84d0f03e1 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -125,6 +125,18 @@ func putString(w []byte, v reflect.Value, rv reflect.Value) error { return fmt.Errorf("%w: need to use uint8 as underlying if want array output from longstring", ErrDecode) } reflect.Copy(v, reflect.ValueOf(w)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + val, err := BeInt(w, 0, len(w)) + if err != nil { + return err + } + v.SetInt(int64(val)) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + val, err := BeInt(w, 0, len(w)) + if err != nil { + return err + } + v.SetUint(uint64(val)) case reflect.Invalid: // do nothing return nil diff --git a/rlp/decoder_test.go b/rlp/decoder_test.go index 02cb825d3..fc556d9f3 100644 --- a/rlp/decoder_test.go +++ b/rlp/decoder_test.go @@ -1,7 +1,34 @@ package rlp_test -import "testing" +import ( + "testing" -func TestHelloWorld(t *testing.T) { - // t.Fatal("not implemented") + "github.com/ledgerwatch/erigon-lib/rlp" + "github.com/stretchr/testify/require" +) + +func TestDecoder(t *testing.T) { + t.Run("ShortString", func(t *testing.T) { + t.Run("ToString", func(t *testing.T) { + bts := []byte{0x83, 'd', 'o', 'g'} + var s string + err := rlp.Unmarshal(bts, &s) + require.NoError(t, err) + require.EqualValues(t, "dog", s) + }) + t.Run("ToBytes", func(t *testing.T) { + bts := []byte{0x83, 'd', 'o', 'g'} + var s []byte + err := rlp.Unmarshal(bts, &s) + require.NoError(t, err) + require.EqualValues(t, []byte("dog"), s) + }) + t.Run("ToInt", func(t *testing.T) { + bts := []byte{0x82, 0x04, 0x00} + var s int + err := rlp.Unmarshal(bts, &s) + require.NoError(t, err) + require.EqualValues(t, 1024, s) + }) + }) } diff --git a/rlp/parse.go b/rlp/parse.go index 23ffc4637..449277f09 100644 --- a/rlp/parse.go +++ b/rlp/parse.go @@ -37,7 +37,7 @@ func IsRLPError(err error) bool { return errors.Is(err, ErrBase) } // BeInt parses Big Endian representation of an integer from given payload at given position func BeInt(payload []byte, pos, length int) (int, error) { var r int - if pos+length >= len(payload) { + if pos+length > len(payload) { return 0, ErrUnexpectedEOF } if length > 0 && payload[pos] == 0 { From 604a625f1148d3c8a0f81bba1de4ec622121004d Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 20:43:14 -0500 Subject: [PATCH 09/34] rename --- rlp/{decoder.go => unmarshaler.go} | 8 ++++++++ rlp/{decoder_test.go => unmarshaler_test.go} | 0 2 files changed, 8 insertions(+) rename rlp/{decoder.go => unmarshaler.go} (97%) rename rlp/{decoder_test.go => unmarshaler_test.go} (100%) diff --git a/rlp/decoder.go b/rlp/unmarshaler.go similarity index 97% rename from rlp/decoder.go rename to rlp/unmarshaler.go index 84d0f03e1..2befb3551 100644 --- a/rlp/decoder.go +++ b/rlp/unmarshaler.go @@ -6,6 +6,14 @@ import ( "reflect" ) +type Unmarshaler interface { + UnmarshalRLP(data []byte) error +} + +type Marshaler interface { + MarshalRLP() ([]byte, error) +} + func Unmarshal(data []byte, val any) error { buf := bytes.NewBuffer(data) return unmarshal(buf, val) diff --git a/rlp/decoder_test.go b/rlp/unmarshaler_test.go similarity index 100% rename from rlp/decoder_test.go rename to rlp/unmarshaler_test.go From 68a9149e6826a29a8b6193bbb073337c062cb536 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 20:58:02 -0500 Subject: [PATCH 10/34] struct --- rlp/decoder.go | 4 ++++ rlp/unmarshaler.go | 13 +++++++++++++ rlp/unmarshaler_test.go | 13 +++++++++++++ 3 files changed, 30 insertions(+) create mode 100644 rlp/decoder.go diff --git a/rlp/decoder.go b/rlp/decoder.go new file mode 100644 index 000000000..a065e8479 --- /dev/null +++ b/rlp/decoder.go @@ -0,0 +1,4 @@ +package rlp + +type Decoder struct { +} diff --git a/rlp/unmarshaler.go b/rlp/unmarshaler.go index 2befb3551..6eae84466 100644 --- a/rlp/unmarshaler.go +++ b/rlp/unmarshaler.go @@ -171,6 +171,19 @@ func reflectList(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { return err } v.SetMapIndex(rv1, rv2) + case reflect.Struct: + for idx := 0; idx < v.NumField(); idx++ { + // Decode into element. + rv1 := v.Field(idx).Addr() + v1 := rv1.Elem() + shouldSet := v1.CanSet() + if shouldSet { + err := reflectAny(w, v1, rv1) + if err != nil { + return err + } + } + } case reflect.Array, reflect.Slice: idx := 0 for { diff --git a/rlp/unmarshaler_test.go b/rlp/unmarshaler_test.go index fc556d9f3..1e2f8145a 100644 --- a/rlp/unmarshaler_test.go +++ b/rlp/unmarshaler_test.go @@ -8,6 +8,12 @@ import ( ) func TestDecoder(t *testing.T) { + + type simple struct { + Key string + Value string + } + t.Run("ShortString", func(t *testing.T) { t.Run("ToString", func(t *testing.T) { bts := []byte{0x83, 'd', 'o', 'g'} @@ -30,5 +36,12 @@ func TestDecoder(t *testing.T) { require.NoError(t, err) require.EqualValues(t, 1024, s) }) + t.Run("ToSimpleStruct", func(t *testing.T) { + bts := []byte{0xc8, 0x83, 'c', 'a', 't', 0x83, 'd', 'o', 'g'} + var s simple + err := rlp.Unmarshal(bts, &s) + require.NoError(t, err) + require.EqualValues(t, simple{Key: "cat", Value: "dog"}, s) + }) }) } From fb28f27fc2e11e88c3233a629b1bb350d366b4cb Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 20:59:51 -0500 Subject: [PATCH 11/34] ok --- rlp/unmarshaler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rlp/unmarshaler.go b/rlp/unmarshaler.go index 6eae84466..601be3555 100644 --- a/rlp/unmarshaler.go +++ b/rlp/unmarshaler.go @@ -17,8 +17,8 @@ type Marshaler interface { func Unmarshal(data []byte, val any) error { buf := bytes.NewBuffer(data) return unmarshal(buf, val) - } + func unmarshal(buf *bytes.Buffer, val any) error { rv := reflect.ValueOf(val) if rv.Kind() != reflect.Pointer || rv.IsNil() { From 404bf1f9536160131ae8cdac91404941da87a0ed Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 21:01:59 -0500 Subject: [PATCH 12/34] ok --- rlp/unmarshaler.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rlp/unmarshaler.go b/rlp/unmarshaler.go index 601be3555..2fe4fcda3 100644 --- a/rlp/unmarshaler.go +++ b/rlp/unmarshaler.go @@ -175,8 +175,9 @@ func reflectList(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { for idx := 0; idx < v.NumField(); idx++ { // Decode into element. rv1 := v.Field(idx).Addr() + rt1 := v.Type().Field(idx) v1 := rv1.Elem() - shouldSet := v1.CanSet() + shouldSet := rt1.IsExported() if shouldSet { err := reflectAny(w, v1, rv1) if err != nil { From 52ccff62edac0467a0c1cd7bb460adb1a5de0429 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 21:09:11 -0500 Subject: [PATCH 13/34] unmarshaler --- rlp/unmarshaler.go | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/rlp/unmarshaler.go b/rlp/unmarshaler.go index 2fe4fcda3..7b2dd5411 100644 --- a/rlp/unmarshaler.go +++ b/rlp/unmarshaler.go @@ -10,10 +10,6 @@ type Unmarshaler interface { UnmarshalRLP(data []byte) error } -type Marshaler interface { - MarshalRLP() ([]byte, error) -} - func Unmarshal(data []byte, val any) error { buf := bytes.NewBuffer(data) return unmarshal(buf, val) @@ -66,6 +62,9 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { // switch switch token { case TokenDecimal: + if um, ok := rv.Interface().(Unmarshaler); ok { + return um.UnmarshalRLP([]byte{prefix}) + } // in this case, the value is just the byte itself switch v.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -83,6 +82,9 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { if err != nil { return err } + if um, ok := rv.Interface().(Unmarshaler); ok { + return um.UnmarshalRLP(str) + } return putString(str, v, rv) case TokenLongString: lenSz := int(token.Diff(prefix)) @@ -94,6 +96,9 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { if err != nil { return err } + if um, ok := rv.Interface().(Unmarshaler); ok { + return um.UnmarshalRLP(str) + } return putString(str, v, rv) case TokenShortList: sz := int(token.Diff(prefix)) @@ -101,6 +106,9 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { if err != nil { return err } + if um, ok := rv.Interface().(Unmarshaler); ok { + return um.UnmarshalRLP(buf) + } return reflectList(bytes.NewBuffer(buf), v, rv) case TokenLongList: lenSz := int(token.Diff(prefix)) @@ -112,6 +120,9 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { if err != nil { return err } + if um, ok := rv.Interface().(Unmarshaler); ok { + return um.UnmarshalRLP(buf) + } return reflectList(bytes.NewBuffer(buf), v, rv) case TokenUnknown: return fmt.Errorf("%w: unknown token", ErrDecode) @@ -120,6 +131,9 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { } func putString(w []byte, v reflect.Value, rv reflect.Value) error { + if um, ok := rv.Interface().(Unmarshaler); ok { + return um.UnmarshalRLP(w) + } switch v.Kind() { case reflect.String: v.SetString(string(w)) @@ -153,6 +167,9 @@ func putString(w []byte, v reflect.Value, rv reflect.Value) error { } func reflectList(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { + if um, ok := rv.Interface().(Unmarshaler); ok { + return um.UnmarshalRLP(w.Bytes()) + } switch v.Kind() { case reflect.Invalid: // do nothing From d7488c09da641fc37a006c8e94bf8179907cf6c3 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 21:16:55 -0500 Subject: [PATCH 14/34] ok --- rlp/unmarshaler.go | 24 +++--------------------- rlp/unmarshaler_test.go | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/rlp/unmarshaler.go b/rlp/unmarshaler.go index 7b2dd5411..6ccca7635 100644 --- a/rlp/unmarshaler.go +++ b/rlp/unmarshaler.go @@ -53,6 +53,9 @@ func decodeBeInt(w *bytes.Buffer, length int) (int, error) { } func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { + if um, ok := rv.Interface().(Unmarshaler); ok { + return um.UnmarshalRLP(w.Bytes()) + } // figure out what we are reading prefix, err := w.ReadByte() if err != nil { @@ -62,9 +65,6 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { // switch switch token { case TokenDecimal: - if um, ok := rv.Interface().(Unmarshaler); ok { - return um.UnmarshalRLP([]byte{prefix}) - } // in this case, the value is just the byte itself switch v.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: @@ -82,9 +82,6 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { if err != nil { return err } - if um, ok := rv.Interface().(Unmarshaler); ok { - return um.UnmarshalRLP(str) - } return putString(str, v, rv) case TokenLongString: lenSz := int(token.Diff(prefix)) @@ -96,9 +93,6 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { if err != nil { return err } - if um, ok := rv.Interface().(Unmarshaler); ok { - return um.UnmarshalRLP(str) - } return putString(str, v, rv) case TokenShortList: sz := int(token.Diff(prefix)) @@ -106,9 +100,6 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { if err != nil { return err } - if um, ok := rv.Interface().(Unmarshaler); ok { - return um.UnmarshalRLP(buf) - } return reflectList(bytes.NewBuffer(buf), v, rv) case TokenLongList: lenSz := int(token.Diff(prefix)) @@ -120,9 +111,6 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { if err != nil { return err } - if um, ok := rv.Interface().(Unmarshaler); ok { - return um.UnmarshalRLP(buf) - } return reflectList(bytes.NewBuffer(buf), v, rv) case TokenUnknown: return fmt.Errorf("%w: unknown token", ErrDecode) @@ -131,9 +119,6 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { } func putString(w []byte, v reflect.Value, rv reflect.Value) error { - if um, ok := rv.Interface().(Unmarshaler); ok { - return um.UnmarshalRLP(w) - } switch v.Kind() { case reflect.String: v.SetString(string(w)) @@ -167,9 +152,6 @@ func putString(w []byte, v reflect.Value, rv reflect.Value) error { } func reflectList(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { - if um, ok := rv.Interface().(Unmarshaler); ok { - return um.UnmarshalRLP(w.Bytes()) - } switch v.Kind() { case reflect.Invalid: // do nothing diff --git a/rlp/unmarshaler_test.go b/rlp/unmarshaler_test.go index 1e2f8145a..bd0be6a87 100644 --- a/rlp/unmarshaler_test.go +++ b/rlp/unmarshaler_test.go @@ -7,6 +7,18 @@ import ( "github.com/stretchr/testify/require" ) +type plusOne int + +func (p *plusOne) UnmarshalRLP(data []byte) error { + var s int + err := rlp.Unmarshal(data, &s) + if err != nil { + return err + } + (*p) = plusOne(s + 1) + return nil +} + func TestDecoder(t *testing.T) { type simple struct { @@ -36,6 +48,13 @@ func TestDecoder(t *testing.T) { require.NoError(t, err) require.EqualValues(t, 1024, s) }) + t.Run("ToIntUnmarshaler", func(t *testing.T) { + bts := []byte{0x82, 0x04, 0x00} + var s plusOne + err := rlp.Unmarshal(bts, &s) + require.NoError(t, err) + require.EqualValues(t, plusOne(1025), s) + }) t.Run("ToSimpleStruct", func(t *testing.T) { bts := []byte{0xc8, 0x83, 'c', 'a', 't', 0x83, 'd', 'o', 'g'} var s simple From 589bfbf621df83b5052b1905349435e41a30732c Mon Sep 17 00:00:00 2001 From: a Date: Thu, 31 Aug 2023 22:18:16 -0500 Subject: [PATCH 15/34] weird thing i made --- rlp/decoder.go | 99 ++++++++++++++++++++++++++++++++++++++++++++++ rlp/encoder.go | 4 +- rlp/readme.md | 11 ++++++ rlp/unmarshaler.go | 38 ++++-------------- rlp/util.go | 39 +++++++++++++----- 5 files changed, 149 insertions(+), 42 deletions(-) create mode 100644 rlp/readme.md diff --git a/rlp/decoder.go b/rlp/decoder.go index a065e8479..0a2ebe41a 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -1,4 +1,103 @@ package rlp +import ( + "bytes" + "fmt" +) + type Decoder struct { + *bytes.Buffer +} + +func NewDecoder(buf []byte) *Decoder { + return &Decoder{ + Buffer: bytes.NewBuffer(buf), + } +} + +func (d *Decoder) List() (*Decoder, error) { + w := d.Buffer + // figure out what we are reading + prefix, err := w.ReadByte() + if err != nil { + return nil, err + } + token := identifyToken(prefix) + // switch + switch token { + case TokenShortList: + sz := int(token.Diff(prefix)) + buf, err := nextFull(w, sz) + if err != nil { + return nil, err + } + return NewDecoder(buf), nil + case TokenLongList: + lenSz := int(token.Diff(prefix)) + sz, err := nextBeInt(w, lenSz) + if err != nil { + return nil, err + } + buf, err := nextFull(w, sz) + if err != nil { + return nil, err + } + return NewDecoder(buf), nil + default: + return nil, fmt.Errorf("%w: List on non-list token", ErrDecode) + } +} + +func DecodeBlob[T any](fn func(*T, []byte) error, receiver *T) func(d *Decoder) error { + return func(d *Decoder) error { + // figure out what we are reading + prefix, err := d.ReadByte() + if err != nil { + return err + } + token := identifyToken(prefix) + switch token { + case TokenDecimal: + // in this case, the value is just the byte itself + return fn(receiver, []byte{prefix}) + case TokenShortBlob: + sz := int(token.Diff(prefix)) + str, err := nextFull(d.Buffer, sz) + if err != nil { + return err + } + return fn(receiver, str) + case TokenLongBlob: + lenSz := int(token.Diff(prefix)) + sz, err := nextBeInt(d.Buffer, lenSz) + if err != nil { + return err + } + str, err := nextFull(d.Buffer, sz) + if err != nil { + return err + } + return fn(receiver, str) + default: + return fmt.Errorf("%w: DecodeBlob on list token", ErrDecode) + } + } +} + +func DecodeDecimal[T any](fn func(*T, byte) error, receiver *T) func(d *Decoder) error { + return func(d *Decoder) error { + // figure out what we are reading + prefix, err := d.ReadByte() + if err != nil { + return err + } + token := identifyToken(prefix) + switch token { + case TokenDecimal: + // in this case, the value is just the byte itself + return fn(receiver, prefix) + default: + return fmt.Errorf("%w: DecodeDecimal on non-decimal token", ErrDecode) + } + } } diff --git a/rlp/encoder.go b/rlp/encoder.go index 972286e4b..402378b18 100644 --- a/rlp/encoder.go +++ b/rlp/encoder.go @@ -39,13 +39,13 @@ func (e *Encoder) Str(str []byte) *Encoder { // String will assume your string is less than 56 bytes long, and do no validation as such func (e *Encoder) ShortString(str []byte) *Encoder { - return e.Byte(TokenShortString.Plus(byte(len(str)))).Bytes(str) + return e.Byte(TokenShortBlob.Plus(byte(len(str)))).Bytes(str) } // String will assume your string is greater than 55 bytes long, and do no validation as such func (e *Encoder) LongString(str []byte) *Encoder { // write the indicator token - e.Byte(byte(TokenLongString)) + e.Byte(byte(TokenLongBlob)) // write the integer, knowing that we appended n bytes n := putUint(e, len(str)) // so we knw the indicator token was n+1 bytes ago. diff --git a/rlp/readme.md b/rlp/readme.md new file mode 100644 index 000000000..74e9f96ee --- /dev/null +++ b/rlp/readme.md @@ -0,0 +1,11 @@ +## rlp + + +TERMINOLOGY: + +``` +RLP string = "Blob" // this is so we don't conflict with existing go name for String +RLP list = "List" // luckily we can keep using list name since go doesn't use it +RLP single byte number = "Decimal" // for numbers from 1-127. a special case +``` + diff --git a/rlp/unmarshaler.go b/rlp/unmarshaler.go index 6ccca7635..4c11faff6 100644 --- a/rlp/unmarshaler.go +++ b/rlp/unmarshaler.go @@ -28,30 +28,6 @@ func unmarshal(buf *bytes.Buffer, val any) error { return nil } -func nextFull(dat *bytes.Buffer, size int) ([]byte, error) { - d := dat.Next(size) - if len(d) != size { - return nil, ErrUnexpectedEOF - } - return d, nil -} - -// BeInt parses Big Endian representation of an integer from given payload at given position -func decodeBeInt(w *bytes.Buffer, length int) (int, error) { - var r int - dat, err := nextFull(w, length) - if err != nil { - return 0, ErrUnexpectedEOF - } - if length > 0 && dat[0] == 0 { - return 0, fmt.Errorf("%w: integer encoding for RLP must not have leading zeros: %x", ErrParse, dat) - } - for _, b := range dat[0:length] { - r = (r << 8) | int(b) - } - return r, nil -} - func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { if um, ok := rv.Interface().(Unmarshaler); ok { return um.UnmarshalRLP(w.Bytes()) @@ -76,16 +52,16 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { default: return fmt.Errorf("%w: decimal must be unmarshal into integer type", ErrDecode) } - case TokenShortString: + case TokenShortBlob: sz := int(token.Diff(prefix)) str, err := nextFull(w, sz) if err != nil { return err } - return putString(str, v, rv) - case TokenLongString: + return putBlob(str, v, rv) + case TokenLongBlob: lenSz := int(token.Diff(prefix)) - sz, err := decodeBeInt(w, lenSz) + sz, err := nextBeInt(w, lenSz) if err != nil { return err } @@ -93,7 +69,7 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { if err != nil { return err } - return putString(str, v, rv) + return putBlob(str, v, rv) case TokenShortList: sz := int(token.Diff(prefix)) buf, err := nextFull(w, sz) @@ -103,7 +79,7 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { return reflectList(bytes.NewBuffer(buf), v, rv) case TokenLongList: lenSz := int(token.Diff(prefix)) - sz, err := decodeBeInt(w, lenSz) + sz, err := nextBeInt(w, lenSz) if err != nil { return err } @@ -118,7 +94,7 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { return nil } -func putString(w []byte, v reflect.Value, rv reflect.Value) error { +func putBlob(w []byte, v reflect.Value, rv reflect.Value) error { switch v.Kind() { case reflect.String: v.SetString(string(w)) diff --git a/rlp/util.go b/rlp/util.go index 0ba45efa0..0035e1470 100644 --- a/rlp/util.go +++ b/rlp/util.go @@ -1,6 +1,10 @@ package rlp -type Token byte +import ( + "bytes" +) + +type Token int32 func (T Token) Plus(n byte) byte { return byte(T) + n @@ -11,13 +15,13 @@ func (T Token) Diff(n byte) byte { } const ( - TokenDecimal Token = 0x00 - TokenShortString Token = 0x80 - TokenLongString Token = 0xb7 - TokenShortList Token = 0xc0 - TokenLongList Token = 0xf7 + TokenDecimal Token = 0x00 + TokenShortBlob Token = 0x80 + TokenLongBlob Token = 0xb7 + TokenShortList Token = 0xc0 + TokenLongList Token = 0xf7 - TokenUnknown Token = 0xff + TokenUnknown Token = -1 ) func identifyToken(b byte) Token { @@ -25,9 +29,9 @@ func identifyToken(b byte) Token { case b >= 0 && b <= 127: return TokenDecimal case b >= 128 && b <= 183: - return TokenShortString + return TokenShortBlob case b >= 184 && b <= 191: - return TokenLongString + return TokenLongBlob case b >= 192 && b <= 247: return TokenShortList case b >= 248 && b <= 255: @@ -35,3 +39,20 @@ func identifyToken(b byte) Token { } return TokenUnknown } + +// BeInt parses Big Endian representation of an integer from given payload at given position +func nextBeInt(w *bytes.Buffer, length int) (int, error) { + dat, err := nextFull(w, length) + if err != nil { + return 0, ErrUnexpectedEOF + } + return BeInt(dat, 0, length) +} + +func nextFull(dat *bytes.Buffer, size int) ([]byte, error) { + d := dat.Next(size) + if len(d) != size { + return nil, ErrUnexpectedEOF + } + return d, nil +} From 5a9ae198a4649f43dea88b619df9b4ef87f9503c Mon Sep 17 00:00:00 2001 From: a Date: Wed, 6 Sep 2023 15:55:47 -0500 Subject: [PATCH 16/34] wip --- rlp/decoder.go | 198 +++++++++++++++------ rlp/types.go | 59 ++++++ rlp/unmarshaler.go | 13 +- rlp/util.go | 19 +- types/txn.go | 13 ++ types/txn_decode.go | 423 ++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 655 insertions(+), 70 deletions(-) create mode 100644 rlp/types.go create mode 100644 types/txn_decode.go diff --git a/rlp/decoder.go b/rlp/decoder.go index 0a2ebe41a..037e066da 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -1,103 +1,189 @@ package rlp import ( - "bytes" "fmt" + "io" ) type Decoder struct { - *bytes.Buffer + buf *buf } func NewDecoder(buf []byte) *Decoder { return &Decoder{ - Buffer: bytes.NewBuffer(buf), + buf: newBuf(buf, 0), } } -func (d *Decoder) List() (*Decoder, error) { - w := d.Buffer +func (d *Decoder) Consumed() []byte { + return d.buf.u[:d.buf.off] +} + +func (d *Decoder) Underlying() []byte { + return d.buf.u +} + +func (d *Decoder) Len() int { + return d.buf.Len() +} + +func (d *Decoder) Offset() int { + return d.buf.Offset() +} + +func (d *Decoder) Bytes() []byte { + return d.buf.Bytes() +} + +func (d *Decoder) ReadByte() (n byte, err error) { + return d.buf.ReadByte() +} + +func (d *Decoder) ElemDec() (*Decoder, Token, error) { + a, t, err := d.Elem() + return NewDecoder(a), t, err +} + +func (d *Decoder) Elem() ([]byte, Token, error) { + w := d.buf // figure out what we are reading prefix, err := w.ReadByte() if err != nil { - return nil, err + return nil, TokenUnknown, err } token := identifyToken(prefix) // switch switch token { + case TokenDecimal: + // in this case, the value is just the byte itself + return []byte{prefix}, token, nil case TokenShortList: sz := int(token.Diff(prefix)) buf, err := nextFull(w, sz) if err != nil { - return nil, err + return nil, token, err } - return NewDecoder(buf), nil + return buf, token, nil case TokenLongList: lenSz := int(token.Diff(prefix)) sz, err := nextBeInt(w, lenSz) if err != nil { - return nil, err + return nil, token, err } buf, err := nextFull(w, sz) if err != nil { - return nil, err + return nil, token, err + } + return buf, token, nil + case TokenShortBlob: + sz := int(token.Diff(prefix)) + str, err := nextFull(w, sz) + if err != nil { + return nil, token, err + } + return str, token, nil + case TokenLongBlob: + lenSz := int(token.Diff(prefix)) + sz, err := nextBeInt(w, lenSz) + if err != nil { + return nil, token, err + } + str, err := nextFull(w, sz) + if err != nil { + return nil, token, err } - return NewDecoder(buf), nil + return str, token, nil default: - return nil, fmt.Errorf("%w: List on non-list token", ErrDecode) + return nil, token, fmt.Errorf("%w: unknown token", ErrDecode) } } -func DecodeBlob[T any](fn func(*T, []byte) error, receiver *T) func(d *Decoder) error { - return func(d *Decoder) error { - // figure out what we are reading - prefix, err := d.ReadByte() - if err != nil { - return err - } - token := identifyToken(prefix) - switch token { - case TokenDecimal: - // in this case, the value is just the byte itself - return fn(receiver, []byte{prefix}) - case TokenShortBlob: - sz := int(token.Diff(prefix)) - str, err := nextFull(d.Buffer, sz) - if err != nil { - return err - } - return fn(receiver, str) - case TokenLongBlob: - lenSz := int(token.Diff(prefix)) - sz, err := nextBeInt(d.Buffer, lenSz) - if err != nil { - return err - } - str, err := nextFull(d.Buffer, sz) +func ReadElem[T any](d *Decoder, fn func(*T, []byte) error, receiver *T) error { + buf, token, err := d.Elem() + if err != nil { + return err + } + switch token { + case TokenDecimal, + TokenShortBlob, + TokenLongBlob, + TokenShortList, + TokenLongList: + // in this case, the value is just the byte itself + return fn(receiver, buf) + default: + return fmt.Errorf("%w: ReadElem found unexpected token", ErrDecode) + } +} + +func (d *Decoder) ForList(fn func(*Decoder) error) error { + // figure out what we are reading + buf, token, err := d.Elem() + if err != nil { + return err + } + switch token { + case TokenShortList, TokenLongList: + dec := NewDecoder(buf) + for dec.buf.Len() > 0 { + err := fn(d) if err != nil { return err } - return fn(receiver, str) - default: - return fmt.Errorf("%w: DecodeBlob on list token", ErrDecode) } + return nil + default: + return fmt.Errorf("%w: ForList on non-list", ErrDecode) } } -func DecodeDecimal[T any](fn func(*T, byte) error, receiver *T) func(d *Decoder) error { - return func(d *Decoder) error { - // figure out what we are reading - prefix, err := d.ReadByte() - if err != nil { - return err - } - token := identifyToken(prefix) - switch token { - case TokenDecimal: - // in this case, the value is just the byte itself - return fn(receiver, prefix) - default: - return fmt.Errorf("%w: DecodeDecimal on non-decimal token", ErrDecode) - } +type buf struct { + u []byte + off int +} + +func newBuf(u []byte, off int) *buf { + return &buf{u: u, off: off} +} + +func (b *buf) empty() bool { return len(b.u) <= b.off } + +func (b *buf) PeekByte() (n byte, err error) { + if len(b.u) <= b.off+1 { + return 0, io.EOF + } + return b.u[b.off+1], nil +} +func (b *buf) ReadByte() (n byte, err error) { + if len(b.u) <= b.off+1 { + return 0, io.EOF + } + b.off++ + return b.u[b.off], nil +} + +func (b *buf) Next(n int) (xs []byte) { + m := b.Len() + if n > m { + n = m } + data := b.u[b.off : b.off+n] + b.off += n + return data } + +func (b *buf) Offset() int { + return b.off +} + +func (b *buf) Bytes() []byte { return b.u[b.off:] } + +func (b *buf) String() string { + if b == nil { + // Special case, useful in debugging. + return "" + } + return string(b.u[b.off:]) +} + +func (b *buf) Len() int { return len(b.u) - b.off } diff --git a/rlp/types.go b/rlp/types.go new file mode 100644 index 000000000..f33bdfdc2 --- /dev/null +++ b/rlp/types.go @@ -0,0 +1,59 @@ +package rlp + +import ( + "fmt" + + "github.com/holiman/uint256" +) + +func Bytes(dst *[]byte, src []byte) error { + if len(*dst) < len(src) { + (*dst) = make([]byte, len(src)) + } + copy(*dst, src) + return nil +} +func BytesExact(dst *[]byte, src []byte) error { + if len(*dst) != len(src) { + return fmt.Errorf("%w: BytesExact no match", ErrDecode) + } + copy(*dst, src) + return nil +} + +func Uint256(dst *uint256.Int, src []byte) error { + if len(src) > 32 { + return fmt.Errorf("%w: uint256 must not be more than 32 bytes long, got %d", ErrParse, len(src)) + } + if len(src) > 0 && src[0] == 0 { + return fmt.Errorf("%w: integer encoding for RLP must not have leading zeros: %x", ErrParse, src) + } + dst.SetBytes(src) + return nil +} + +func Uint64(dst *uint64, src []byte) error { + var r uint64 + for _, b := range src { + r = (r << 8) | uint64(b) + } + (*dst) = r + return nil +} + +func IsEmpty(dst *bool, src []byte) error { + if len(src) == 0 { + (*dst) = true + } else { + (*dst) = false + } + return nil +} +func BlobLength(dst *int, src []byte) error { + (*dst) = len(src) + return nil +} + +func Skip(dst *int, src []byte) error { + return nil +} diff --git a/rlp/unmarshaler.go b/rlp/unmarshaler.go index 4c11faff6..16c42a1f2 100644 --- a/rlp/unmarshaler.go +++ b/rlp/unmarshaler.go @@ -1,7 +1,6 @@ package rlp import ( - "bytes" "fmt" "reflect" ) @@ -11,11 +10,11 @@ type Unmarshaler interface { } func Unmarshal(data []byte, val any) error { - buf := bytes.NewBuffer(data) + buf := newBuf(data, 0) return unmarshal(buf, val) } -func unmarshal(buf *bytes.Buffer, val any) error { +func unmarshal(buf *buf, val any) error { rv := reflect.ValueOf(val) if rv.Kind() != reflect.Pointer || rv.IsNil() { return fmt.Errorf("%w: v must be ptr", ErrDecode) @@ -28,7 +27,7 @@ func unmarshal(buf *bytes.Buffer, val any) error { return nil } -func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { +func reflectAny(w *buf, v reflect.Value, rv reflect.Value) error { if um, ok := rv.Interface().(Unmarshaler); ok { return um.UnmarshalRLP(w.Bytes()) } @@ -76,7 +75,7 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { if err != nil { return err } - return reflectList(bytes.NewBuffer(buf), v, rv) + return reflectList(newBuf(buf, 0), v, rv) case TokenLongList: lenSz := int(token.Diff(prefix)) sz, err := nextBeInt(w, lenSz) @@ -87,7 +86,7 @@ func reflectAny(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { if err != nil { return err } - return reflectList(bytes.NewBuffer(buf), v, rv) + return reflectList(newBuf(buf, 0), v, rv) case TokenUnknown: return fmt.Errorf("%w: unknown token", ErrDecode) } @@ -127,7 +126,7 @@ func putBlob(w []byte, v reflect.Value, rv reflect.Value) error { return nil } -func reflectList(w *bytes.Buffer, v reflect.Value, rv reflect.Value) error { +func reflectList(w *buf, v reflect.Value, rv reflect.Value) error { switch v.Kind() { case reflect.Invalid: // do nothing diff --git a/rlp/util.go b/rlp/util.go index 0035e1470..83e6b48a5 100644 --- a/rlp/util.go +++ b/rlp/util.go @@ -1,9 +1,5 @@ package rlp -import ( - "bytes" -) - type Token int32 func (T Token) Plus(n byte) byte { @@ -14,6 +10,14 @@ func (T Token) Diff(n byte) byte { return n - byte(T) } +func (T Token) IsListType() bool { + return T == TokenLongList || T == TokenShortList +} + +func (T Token) IsBlobType() bool { + return T == TokenLongBlob || T == TokenShortBlob +} + const ( TokenDecimal Token = 0x00 TokenShortBlob Token = 0x80 @@ -21,7 +25,8 @@ const ( TokenShortList Token = 0xc0 TokenLongList Token = 0xf7 - TokenUnknown Token = -1 + TokenUnknown Token = 0xff01 + TokenEOF Token = 0xdead ) func identifyToken(b byte) Token { @@ -41,7 +46,7 @@ func identifyToken(b byte) Token { } // BeInt parses Big Endian representation of an integer from given payload at given position -func nextBeInt(w *bytes.Buffer, length int) (int, error) { +func nextBeInt(w *buf, length int) (int, error) { dat, err := nextFull(w, length) if err != nil { return 0, ErrUnexpectedEOF @@ -49,7 +54,7 @@ func nextBeInt(w *bytes.Buffer, length int) (int, error) { return BeInt(dat, 0, length) } -func nextFull(dat *bytes.Buffer, size int) ([]byte, error) { +func nextFull(dat *buf, size int) ([]byte, error) { d := dat.Next(size) if len(d) != size { return nil, ErrUnexpectedEOF diff --git a/types/txn.go b/types/txn.go index a1e7f8224..4944d95b4 100644 --- a/types/txn.go +++ b/types/txn.go @@ -146,7 +146,20 @@ func PeekTransactionType(serialized []byte) (byte, error) { // It also performs syntactic validation of the transactions. // wrappedWithBlobs means that for blob (type 3) transactions the full version with blobs/commitments/proofs is expected // (see https://eips.ethereum.org/EIPS/eip-4844#networking). +// +// [ ] || [ nonce, price, limit, to, value, data, y,r,s] +// 0x01 || [chain_id, nonce, price, limit, to, value, data, access_list, y,r,s] +// 0x02 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, y,r,s] +// 0x03 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s] +// 0x03 ||[[chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s], blobs, commitments, proofs] func (ctx *TxParseContext) ParseTransaction(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (p int, err error) { + err = ctx.parseTransaction2(payload, pos, slot, sender, hasEnvelope, wrappedWithBlobs, validateHash) + if err != nil { + return 0, fmt.Errorf("%w: %w", ErrParseTxn, err) + } + return len(payload), nil +} +func (ctx *TxParseContext) parseTransaction(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (p int, err error) { if len(payload) == 0 { return 0, fmt.Errorf("%w: empty rlp", ErrParseTxn) } diff --git a/types/txn_decode.go b/types/txn_decode.go new file mode 100644 index 000000000..12d647647 --- /dev/null +++ b/types/txn_decode.go @@ -0,0 +1,423 @@ +package types + +import ( + "encoding/binary" + "fmt" + "io" + "math/bits" + + gokzg4844 "github.com/crate-crypto/go-kzg-4844" + "github.com/ledgerwatch/erigon-lib/common" + "github.com/ledgerwatch/erigon-lib/common/u256" + "github.com/ledgerwatch/erigon-lib/crypto" + "github.com/ledgerwatch/erigon-lib/rlp" + "github.com/ledgerwatch/secp256k1" +) + +func (ctx *TxParseContext) parseTransaction2(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (err error) { + if len(payload) == 0 { + return fmt.Errorf("empty rlp") + } + if ctx.withSender && len(sender) != 20 { + return fmt.Errorf("expect sender buffer of len 20") + } + decoder := rlp.NewDecoder(payload[pos:]) + + dec, tok, err := decoder.ElemDec() + if err != nil { + return fmt.Errorf("size prefix: %w", err) //nolint + } + + if dec.Len() == 0 { + return fmt.Errorf("transaction must be either 1 list or 1 string") + } + if dec.Len() == 1 && !tok.IsListType() { + if hasEnvelope { + return fmt.Errorf("expected envelope in the payload, got %x", dec.Bytes()[0]) + } + } + + // Legacy transactions have list Prefix, whereas EIP-2718 transactions have string Prefix + // therefore we assign the first returned value of Prefix function (list) to legacy variable + switch { + case tok.IsListType(): + slot.Rlp = append(make([]byte, 0, dec.Len()), dec.Bytes()...) + slot.Size = uint32(len(slot.Rlp)) + slot.Type = LegacyTxType + case tok == rlp.TokenDecimal: + slot.Type, err = dec.ReadByte() + if err != nil { + return fmt.Errorf("couldnt read txn type: %w", err) + } + if slot.Type > BlobTxType { + return fmt.Errorf("unknown transaction type: %d", slot.Type) + } + dec, tok, err = dec.ElemDec() + if err != nil { + return err + } + if !tok.IsListType() { + return fmt.Errorf("expected list token") + } + slot.Rlp = append(make([]byte, 0, dec.Len()), dec.Bytes()...) + slot.Size = uint32(len(slot.Rlp)) + default: + return fmt.Errorf("expected list or decimal token") + } + + bodyDecoder := dec + // if its a blob transaction, we actually need to enter a nested list, since its [tx_payload_body, blobs, commitments, proofs] + if slot.Type == BlobTxType && wrappedWithBlobs { + bodyDecoder, _, err = dec.ElemDec() + if err != nil { + return fmt.Errorf("wrapped blob tx: %w", err) //nolint + } + } + + err = ctx.parseTransactionBody2(bodyDecoder, slot, sender, validateHash) + if err != nil { + return err + } + + // so its a blob transaction and we need to do the extra stuff... + if slot.Type == BlobTxType && wrappedWithBlobs { + if err := ctx.parseBlobs(dec, slot); err != nil { + return err + } + if err := ctx.parseCommitments(dec, slot); err != nil { + return err + } + if err := ctx.parseProofs(dec, slot); err != nil { + return err + } + if len(slot.Blobs) != len(slot.Commitments) { + return fmt.Errorf("blob count != commitment count") + } + if len(slot.Commitments) != len(slot.Proofs) { + return fmt.Errorf("commitment count != proof count") + } + if len(slot.BlobHashes) != len(slot.Blobs) { + return fmt.Errorf("blob count != blob hash count") + } + } + return err +} + +func (ctx *TxParseContext) parseCommitments(dec *rlp.Decoder, slot *TxSlot) (err error) { + err = dec.ForList(func(d *rlp.Decoder) error { + var blob gokzg4844.KZGCommitment + blobSlice := blob[:] + err := rlp.ReadElem(dec, rlp.BytesExact, &blobSlice) + if err != nil { + return err + } + slot.Commitments = append(slot.Commitments, blob) + return nil + }) + + if err != nil { + return err + } + return nil +} + +func (ctx *TxParseContext) parseProofs(dec *rlp.Decoder, slot *TxSlot) (err error) { + err = dec.ForList(func(d *rlp.Decoder) error { + var blob gokzg4844.KZGProof + blobSlice := blob[:] + err := rlp.ReadElem(dec, rlp.BytesExact, &blobSlice) + if err != nil { + return err + } + slot.Proofs = append(slot.Proofs, blob) + return nil + }) + if err != nil { + return err + } + return nil +} + +func (ctx *TxParseContext) parseBlobs(dec *rlp.Decoder, slot *TxSlot) (err error) { + err = dec.ForList(func(d *rlp.Decoder) error { + var blob []byte + err := rlp.ReadElem(dec, rlp.Bytes, &blob) + if err != nil { + return err + } + slot.Blobs = append(slot.Blobs, blob) + return nil + }) + if err != nil { + return err + } + + return nil +} + +func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, sender []byte, validateHash func([]byte) error) (err error) { + legacy := slot.Type == LegacyTxType + + // Compute transaction hash + ctx.Keccak1.Reset() + ctx.Keccak2.Reset() + if ctx.validateRlp != nil { + if err := ctx.validateRlp(slot.Rlp); err != nil { + return err + } + } + + if !legacy { + err = rlp.ReadElem(dec, rlp.Uint256, &ctx.ChainID) + if err != nil { + return fmt.Errorf("bad chainId: %w", err) //nolint + } + if ctx.ChainID.IsZero() { // zero indicates that the chain ID was not specified in the tx. + if ctx.chainIDRequired { + return fmt.Errorf("chainID is required") + } + ctx.ChainID.Set(&ctx.cfg.ChainID) + } + if !ctx.ChainID.Eq(&ctx.cfg.ChainID) { + return fmt.Errorf("%s, %d (expected %d)", "invalid chainID", ctx.ChainID.Uint64(), ctx.cfg.ChainID.Uint64()) + } + } + // Next follows the nonce, which we need to parse + err = rlp.ReadElem(dec, rlp.Uint64, &slot.Nonce) + if err != nil { + return fmt.Errorf("nonce: %s", err) //nolint + } + // Next follows gas price or tip + err = rlp.ReadElem(dec, rlp.Uint256, &slot.Tip) + if err != nil { + return fmt.Errorf("tip: %s", err) //nolint + } + // Next follows feeCap, but only for dynamic fee transactions, for legacy transaction, it is + // equal to tip + if slot.Type < DynamicFeeTxType { + slot.FeeCap = slot.Tip + } else { + err = rlp.ReadElem(dec, rlp.Uint256, &slot.FeeCap) + if err != nil { + return fmt.Errorf("feeCap: %s", err) //nolint + } + } + // gas limit + err = rlp.ReadElem(dec, rlp.Uint64, &slot.Gas) + if err != nil { + return fmt.Errorf("gas: %s", err) //nolint + } + // recipient + err = rlp.ReadElem(dec, rlp.IsEmpty, &slot.Creation) + if err != nil { + return fmt.Errorf("value: %s", err) //nolint + } + // Next follows value + err = rlp.ReadElem(dec, rlp.Uint256, &slot.Value) + if err != nil { + return fmt.Errorf("value: %s", err) //nolint + } + // Next goes data, but we are only interesting in its length + err = rlp.ReadElem(dec, func(i *int, b []byte) error { + slot.DataLen = len(b) + for _, byt := range b { + if byt != 0 { + slot.DataNonZeroLen++ + } + } + return nil + }, nil) + if err != nil { + return fmt.Errorf("data len: %s", err) //nolint + } + // Zero and non-zero bytes are priced differently + slot.DataNonZeroLen = 0 + // Next follows access list for non-legacy transactions, we are only interesting in number of addresses and storage keys + if !legacy { + err = dec.ForList(func(ld *rlp.Decoder) error { + slot.AlAddrCount++ + err := rlp.ReadElem(ld, rlp.Skip, nil) + if err != nil { + return err + } + err = ld.ForList(func(sk *rlp.Decoder) error { + slot.AlStorCount++ + err := rlp.ReadElem(sk, rlp.Skip, nil) + if err != nil { + return err + } + return nil + }) + return err + }) + } + + if slot.Type == BlobTxType { + err = rlp.ReadElem(dec, rlp.Uint256, &slot.BlobFeeCap) + if err != nil { + return fmt.Errorf("blob fee cap: %s", err) //nolint + } + dec.ForList(func(dec *rlp.Decoder) error { + var blob common.Hash + blobSlice := blob[:] + err := rlp.ReadElem(dec, rlp.BytesExact, &blobSlice) + if err != nil { + return err + } + slot.BlobHashes = append(slot.BlobHashes, blob) + return nil + }) + } + // This is where the data for Sighash ends + + // Next follows V of the signature + var vByte byte + var chainIDBits, chainIDLen int + if legacy { + err = rlp.ReadElem(dec, rlp.Uint256, &ctx.V) + if err != nil { + return fmt.Errorf("V: %s", err) //nolint + } + ctx.IsProtected = ctx.V.Eq(u256.N27) || ctx.V.Eq(u256.N28) + // Compute chainId from V + if ctx.IsProtected { + // Do not add chain id and two extra zeros + vByte = byte(ctx.V.Uint64() - 27) + ctx.ChainID.Set(&ctx.cfg.ChainID) + } else { + ctx.ChainID.Sub(&ctx.V, u256.N35) + ctx.ChainID.Rsh(&ctx.ChainID, 1) + if !ctx.ChainID.Eq(&ctx.cfg.ChainID) { + return fmt.Errorf("%s, %d (expected %d)", "invalid chainID", ctx.ChainID.Uint64(), ctx.cfg.ChainID.Uint64()) + } + + chainIDBits = ctx.ChainID.BitLen() + if chainIDBits <= 7 { + chainIDLen = 1 + } else { + chainIDLen = common.BitLenToByteLen(chainIDBits) // It is always < 56 bytes + } + ctx.DeriveChainID.Sub(&ctx.V, &ctx.ChainIDMul) + vByte = byte(ctx.DeriveChainID.Sub(&ctx.DeriveChainID, u256.N8).Uint64() - 27) + } + } else { + var v uint64 + err = rlp.ReadElem(dec, rlp.Uint64, &v) + if err != nil { + return fmt.Errorf("V: %s", err) //nolint + } + if v > 1 { + return fmt.Errorf("V is loo large: %d", v) + } + vByte = byte(v) + ctx.IsProtected = true + } + + // Next follows R of the signature + err = rlp.ReadElem(dec, rlp.Uint256, &ctx.R) + if err != nil { + return fmt.Errorf("R: %s", err) //nolint + } + // New follows S of the signature + err = rlp.ReadElem(dec, rlp.Uint256, &ctx.S) + if err != nil { + return fmt.Errorf("S: %s", err) //nolint + } + + if _, err = ctx.Keccak1.Write([]byte{slot.Type}); err != nil { + return fmt.Errorf("computing IdHash: %s", err) //nolint + } + + // For legacy transactions, hash the full payload + if legacy { + if _, err = ctx.Keccak1.Write(dec.Consumed()); err != nil { + return fmt.Errorf("computing IdHash: %s", err) //nolint + } + } + _, _ = ctx.Keccak1.(io.Reader).Read(slot.IDHash[:32]) + if validateHash != nil { + if err := validateHash(slot.IDHash[:32]); err != nil { + return err + } + } + + if !ctx.withSender { + return nil + } + + if !crypto.TransactionSignatureIsValid(vByte, &ctx.R, &ctx.S, ctx.allowPreEip2s && legacy) { + return fmt.Errorf("invalid v, r, s: %d, %s, %s", vByte, &ctx.R, &ctx.S) + } + + // Computing sigHash (hash used to recover sender from the signature) + // Write len Prefix to the Sighash + if dec.Offset() < 56 { + ctx.buf[0] = byte(dec.Offset()) + 192 + if _, err := ctx.Keccak2.Write(ctx.buf[:1]); err != nil { + return fmt.Errorf("computing signHash (hashing len Prefix): %s", err) //nolint + } + } else { + beLen := common.BitLenToByteLen(bits.Len(uint(dec.Offset()))) + binary.BigEndian.PutUint64(ctx.buf[1:], uint64(dec.Offset())) + ctx.buf[8-beLen] = byte(beLen) + 247 + if _, err := ctx.Keccak2.Write(ctx.buf[8-beLen : 9]); err != nil { + return fmt.Errorf("computing signHash (hashing len Prefix): %s", err) //nolint + } + } + if _, err = ctx.Keccak2.Write(dec.Consumed()); err != nil { + return fmt.Errorf("computing signHash: %s", err) //nolint + } + if legacy { + if chainIDLen > 0 { + if chainIDBits <= 7 { + ctx.buf[0] = byte(ctx.ChainID.Uint64()) + if _, err := ctx.Keccak2.Write(ctx.buf[:1]); err != nil { + return fmt.Errorf("computing signHash (hashing legacy chainId): %s", err) //nolint + } + } else { + binary.BigEndian.PutUint64(ctx.buf[1:9], ctx.ChainID[3]) + binary.BigEndian.PutUint64(ctx.buf[9:17], ctx.ChainID[2]) + binary.BigEndian.PutUint64(ctx.buf[17:25], ctx.ChainID[1]) + binary.BigEndian.PutUint64(ctx.buf[25:33], ctx.ChainID[0]) + ctx.buf[32-chainIDLen] = 128 + byte(chainIDLen) + if _, err = ctx.Keccak2.Write(ctx.buf[32-chainIDLen : 33]); err != nil { + return fmt.Errorf("computing signHash (hashing legacy chainId): %s", err) //nolint + } + } + // Encode two zeros + ctx.buf[0] = 128 + ctx.buf[1] = 128 + if _, err := ctx.Keccak2.Write(ctx.buf[:2]); err != nil { + return fmt.Errorf("computing signHash (hashing zeros after legacy chainId): %s", err) //nolint + } + } + } + // Squeeze Sighash + _, _ = ctx.Keccak2.(io.Reader).Read(ctx.Sighash[:32]) + //ctx.keccak2.Sum(ctx.Sighash[:0]) + binary.BigEndian.PutUint64(ctx.Sig[0:8], ctx.R[3]) + binary.BigEndian.PutUint64(ctx.Sig[8:16], ctx.R[2]) + binary.BigEndian.PutUint64(ctx.Sig[16:24], ctx.R[1]) + binary.BigEndian.PutUint64(ctx.Sig[24:32], ctx.R[0]) + binary.BigEndian.PutUint64(ctx.Sig[32:40], ctx.S[3]) + binary.BigEndian.PutUint64(ctx.Sig[40:48], ctx.S[2]) + binary.BigEndian.PutUint64(ctx.Sig[48:56], ctx.S[1]) + binary.BigEndian.PutUint64(ctx.Sig[56:64], ctx.S[0]) + ctx.Sig[64] = vByte + // recover sender + if _, err = secp256k1.RecoverPubkeyWithContext(secp256k1.DefaultContext, ctx.Sighash[:], ctx.Sig[:], ctx.buf[:0]); err != nil { + return fmt.Errorf("recovering sender from signature: %s", err) //nolint + } + //apply keccak to the public key + ctx.Keccak2.Reset() + if _, err = ctx.Keccak2.Write(ctx.buf[1:65]); err != nil { + return fmt.Errorf("computing sender from public key: %s", err) //nolint + } + // squeeze the hash of the public key + //ctx.keccak2.Sum(ctx.buf[:0]) + _, _ = ctx.Keccak2.(io.Reader).Read(ctx.buf[:32]) + //take last 20 bytes as address + copy(sender, ctx.buf[12:32]) + + return nil +} From 4e9df7144583ef0e25f0e8049e55652b795a0724 Mon Sep 17 00:00:00 2001 From: a Date: Wed, 6 Sep 2023 16:17:21 -0500 Subject: [PATCH 17/34] wip --- rlp/decoder.go | 9 ++++----- types/txn_decode.go | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/rlp/decoder.go b/rlp/decoder.go index 037e066da..6bbe9688f 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -109,7 +109,6 @@ func ReadElem[T any](d *Decoder, fn func(*T, []byte) error, receiver *T) error { TokenLongBlob, TokenShortList, TokenLongList: - // in this case, the value is just the byte itself return fn(receiver, buf) default: return fmt.Errorf("%w: ReadElem found unexpected token", ErrDecode) @@ -149,17 +148,17 @@ func newBuf(u []byte, off int) *buf { func (b *buf) empty() bool { return len(b.u) <= b.off } func (b *buf) PeekByte() (n byte, err error) { - if len(b.u) <= b.off+1 { + if len(b.u) <= b.off { return 0, io.EOF } - return b.u[b.off+1], nil + return b.u[b.off], nil } func (b *buf) ReadByte() (n byte, err error) { - if len(b.u) <= b.off+1 { + if len(b.u) <= b.off { return 0, io.EOF } b.off++ - return b.u[b.off], nil + return b.u[b.off-1], nil } func (b *buf) Next(n int) (xs []byte) { diff --git a/types/txn_decode.go b/types/txn_decode.go index 12d647647..5a3279fe2 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -33,7 +33,7 @@ func (ctx *TxParseContext) parseTransaction2(payload []byte, pos int, slot *TxSl } if dec.Len() == 1 && !tok.IsListType() { if hasEnvelope { - return fmt.Errorf("expected envelope in the payload, got %x", dec.Bytes()[0]) + return fmt.Errorf("expected envelope in the payload, got %x", dec.Bytes()) } } From 073af043b08087f844617c2f03a1eaa64d370bb9 Mon Sep 17 00:00:00 2001 From: a Date: Wed, 6 Sep 2023 23:35:33 -0500 Subject: [PATCH 18/34] wip --- rlp/decoder.go | 48 +++++++++++++---------------- rlp/util.go | 21 +++++++++++++ types/txn.go | 53 +++++++++++++++++--------------- types/txn_decode.go | 73 +++++++++++++++++++++++++++++--------------- types/txn_packets.go | 22 +++++++++++++ 5 files changed, 140 insertions(+), 77 deletions(-) diff --git a/rlp/decoder.go b/rlp/decoder.go index 6bbe9688f..8baf0d52a 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -52,50 +52,44 @@ func (d *Decoder) Elem() ([]byte, Token, error) { return nil, TokenUnknown, err } token := identifyToken(prefix) - // switch + + var ( + buf []byte + sz int + lenSz int + ) + // switch on the token switch token { case TokenDecimal: // in this case, the value is just the byte itself - return []byte{prefix}, token, nil + buf = []byte{prefix} case TokenShortList: - sz := int(token.Diff(prefix)) - buf, err := nextFull(w, sz) - if err != nil { - return nil, token, err - } - return buf, token, nil + sz = int(token.Diff(prefix)) + buf, err = nextFull(w, sz) case TokenLongList: - lenSz := int(token.Diff(prefix)) - sz, err := nextBeInt(w, lenSz) - if err != nil { - return nil, token, err - } - buf, err := nextFull(w, sz) + lenSz = int(token.Diff(prefix)) + sz, err = nextBeInt(w, lenSz) if err != nil { return nil, token, err } - return buf, token, nil + buf, err = nextFull(w, sz) case TokenShortBlob: sz := int(token.Diff(prefix)) - str, err := nextFull(w, sz) - if err != nil { - return nil, token, err - } - return str, token, nil + buf, err = nextFull(w, sz) case TokenLongBlob: lenSz := int(token.Diff(prefix)) sz, err := nextBeInt(w, lenSz) if err != nil { return nil, token, err } - str, err := nextFull(w, sz) - if err != nil { - return nil, token, err - } - return str, token, nil + buf, err = nextFull(w, sz) default: return nil, token, fmt.Errorf("%w: unknown token", ErrDecode) } + if err != nil { + return nil, token, err + } + return buf, token, nil } func ReadElem[T any](d *Decoder, fn func(*T, []byte) error, receiver *T) error { @@ -116,7 +110,7 @@ func ReadElem[T any](d *Decoder, fn func(*T, []byte) error, receiver *T) error { } func (d *Decoder) ForList(fn func(*Decoder) error) error { - // figure out what we are reading + // grab the list bytes buf, token, err := d.Elem() if err != nil { return err @@ -125,7 +119,7 @@ func (d *Decoder) ForList(fn func(*Decoder) error) error { case TokenShortList, TokenLongList: dec := NewDecoder(buf) for dec.buf.Len() > 0 { - err := fn(d) + err := fn(dec) if err != nil { return err } diff --git a/rlp/util.go b/rlp/util.go index 83e6b48a5..6b04dfb2a 100644 --- a/rlp/util.go +++ b/rlp/util.go @@ -2,6 +2,27 @@ package rlp type Token int32 +func (T Token) String() string { + switch T { + case TokenDecimal: + return "decimal" + case TokenShortBlob: + return "short_blob" + case TokenLongBlob: + return "long_blob" + case TokenShortList: + return "short_list" + case TokenLongList: + return "long_list" + case TokenEOF: + return "eof" + case TokenUnknown: + return "unknown" + default: + return "nan" + } +} + func (T Token) Plus(n byte) byte { return byte(T) + n } diff --git a/types/txn.go b/types/txn.go index 4944d95b4..4931d747a 100644 --- a/types/txn.go +++ b/types/txn.go @@ -21,7 +21,6 @@ import ( "encoding/binary" "errors" "fmt" - "hash" "io" "math/bits" "sort" @@ -29,13 +28,13 @@ import ( gokzg4844 "github.com/crate-crypto/go-kzg-4844" "github.com/holiman/uint256" "github.com/ledgerwatch/secp256k1" - "golang.org/x/crypto/sha3" "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon-lib/common/fixedgas" "github.com/ledgerwatch/erigon-lib/common/length" "github.com/ledgerwatch/erigon-lib/common/u256" "github.com/ledgerwatch/erigon-lib/crypto" + "github.com/ledgerwatch/erigon-lib/crypto/cryptopool" "github.com/ledgerwatch/erigon-lib/gointerfaces/types" "github.com/ledgerwatch/erigon-lib/rlp" ) @@ -47,8 +46,6 @@ type TxParseConfig struct { // TxParseContext is object that is required to parse transactions and turn transaction payload into TxSlot objects // usage of TxContext helps avoid extra memory allocations type TxParseContext struct { - Keccak2 hash.Hash - Keccak1 hash.Hash validateRlp func([]byte) error ChainID uint256.Int // Signature values R uint256.Int // Signature values @@ -72,8 +69,6 @@ func NewTxParseContext(chainID uint256.Int) *TxParseContext { } ctx := &TxParseContext{ withSender: true, - Keccak1: sha3.NewLegacyKeccak256(), - Keccak2: sha3.NewLegacyKeccak256(), } // behave as of London enabled @@ -153,11 +148,13 @@ func PeekTransactionType(serialized []byte) (byte, error) { // 0x03 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s] // 0x03 ||[[chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s], blobs, commitments, proofs] func (ctx *TxParseContext) ParseTransaction(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (p int, err error) { - err = ctx.parseTransaction2(payload, pos, slot, sender, hasEnvelope, wrappedWithBlobs, validateHash) + //dec := rlp.NewDecoder(payload) + p, err = ctx.parseTransaction(payload, pos, slot, sender, hasEnvelope, wrappedWithBlobs, validateHash) if err != nil { return 0, fmt.Errorf("%w: %w", ErrParseTxn, err) } - return len(payload), nil + //return dec.Offset() + pos, nil + return p, nil } func (ctx *TxParseContext) parseTransaction(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (p int, err error) { if len(payload) == 0 { @@ -300,15 +297,21 @@ func (ctx *TxParseContext) parseTransactionBody(payload []byte, pos, p0 int, slo p = p0 legacy := slot.Type == LegacyTxType + k1 := cryptopool.GetLegacyKeccak256() + k2 := cryptopool.GetLegacyKeccak256() + defer cryptopool.ReturnLegacyKeccak256(k1) + defer cryptopool.ReturnLegacyKeccak256(k2) + + k1.Reset() + k2.Reset() + // Compute transaction hash - ctx.Keccak1.Reset() - ctx.Keccak2.Reset() if !legacy { typeByte := []byte{slot.Type} - if _, err = ctx.Keccak1.Write(typeByte); err != nil { + if _, err = k1.Write(typeByte); err != nil { return 0, fmt.Errorf("%w: computing IdHash (hashing type Prefix): %s", ErrParseTxn, err) //nolint } - if _, err = ctx.Keccak2.Write(typeByte); err != nil { + if _, err = k2.Write(typeByte); err != nil { return 0, fmt.Errorf("%w: computing signHash (hashing type Prefix): %s", ErrParseTxn, err) //nolint } dataPos, dataLen, err := rlp.List(payload, p) @@ -316,7 +319,7 @@ func (ctx *TxParseContext) parseTransactionBody(payload []byte, pos, p0 int, slo return 0, fmt.Errorf("%w: envelope Prefix: %s", ErrParseTxn, err) //nolint } // Hash the content of envelope, not the full payload - if _, err = ctx.Keccak1.Write(payload[p : dataPos+dataLen]); err != nil { + if _, err = k1.Write(payload[p : dataPos+dataLen]); err != nil { return 0, fmt.Errorf("%w: computing IdHash (hashing the envelope): %s", ErrParseTxn, err) //nolint } p = dataPos @@ -533,12 +536,12 @@ func (ctx *TxParseContext) parseTransactionBody(payload []byte, pos, p0 int, slo // For legacy transactions, hash the full payload if legacy { - if _, err = ctx.Keccak1.Write(payload[pos:p]); err != nil { + if _, err = k1.Write(payload[pos:p]); err != nil { return 0, fmt.Errorf("%w: computing IdHash: %s", ErrParseTxn, err) //nolint } } //ctx.keccak1.Sum(slot.IdHash[:0]) - _, _ = ctx.Keccak1.(io.Reader).Read(slot.IDHash[:32]) + _, _ = k1.(io.Reader).Read(slot.IDHash[:32]) if validateHash != nil { if err := validateHash(slot.IDHash[:32]); err != nil { return p, err @@ -557,25 +560,25 @@ func (ctx *TxParseContext) parseTransactionBody(payload []byte, pos, p0 int, slo // Write len Prefix to the Sighash if sigHashLen < 56 { ctx.buf[0] = byte(sigHashLen) + 192 - if _, err := ctx.Keccak2.Write(ctx.buf[:1]); err != nil { + if _, err := k2.Write(ctx.buf[:1]); err != nil { return 0, fmt.Errorf("%w: computing signHash (hashing len Prefix): %s", ErrParseTxn, err) //nolint } } else { beLen := common.BitLenToByteLen(bits.Len(sigHashLen)) binary.BigEndian.PutUint64(ctx.buf[1:], uint64(sigHashLen)) ctx.buf[8-beLen] = byte(beLen) + 247 - if _, err := ctx.Keccak2.Write(ctx.buf[8-beLen : 9]); err != nil { + if _, err := k2.Write(ctx.buf[8-beLen : 9]); err != nil { return 0, fmt.Errorf("%w: computing signHash (hashing len Prefix): %s", ErrParseTxn, err) //nolint } } - if _, err = ctx.Keccak2.Write(payload[sigHashPos:sigHashEnd]); err != nil { + if _, err = k2.Write(payload[sigHashPos:sigHashEnd]); err != nil { return 0, fmt.Errorf("%w: computing signHash: %s", ErrParseTxn, err) //nolint } if legacy { if chainIDLen > 0 { if chainIDBits <= 7 { ctx.buf[0] = byte(ctx.ChainID.Uint64()) - if _, err := ctx.Keccak2.Write(ctx.buf[:1]); err != nil { + if _, err := k2.Write(ctx.buf[:1]); err != nil { return 0, fmt.Errorf("%w: computing signHash (hashing legacy chainId): %s", ErrParseTxn, err) //nolint } } else { @@ -584,20 +587,20 @@ func (ctx *TxParseContext) parseTransactionBody(payload []byte, pos, p0 int, slo binary.BigEndian.PutUint64(ctx.buf[17:25], ctx.ChainID[1]) binary.BigEndian.PutUint64(ctx.buf[25:33], ctx.ChainID[0]) ctx.buf[32-chainIDLen] = 128 + byte(chainIDLen) - if _, err = ctx.Keccak2.Write(ctx.buf[32-chainIDLen : 33]); err != nil { + if _, err = k2.Write(ctx.buf[32-chainIDLen : 33]); err != nil { return 0, fmt.Errorf("%w: computing signHash (hashing legacy chainId): %s", ErrParseTxn, err) //nolint } } // Encode two zeros ctx.buf[0] = 128 ctx.buf[1] = 128 - if _, err := ctx.Keccak2.Write(ctx.buf[:2]); err != nil { + if _, err := k2.Write(ctx.buf[:2]); err != nil { return 0, fmt.Errorf("%w: computing signHash (hashing zeros after legacy chainId): %s", ErrParseTxn, err) //nolint } } } // Squeeze Sighash - _, _ = ctx.Keccak2.(io.Reader).Read(ctx.Sighash[:32]) + _, _ = k2.(io.Reader).Read(ctx.Sighash[:32]) //ctx.keccak2.Sum(ctx.Sighash[:0]) binary.BigEndian.PutUint64(ctx.Sig[0:8], ctx.R[3]) binary.BigEndian.PutUint64(ctx.Sig[8:16], ctx.R[2]) @@ -613,13 +616,13 @@ func (ctx *TxParseContext) parseTransactionBody(payload []byte, pos, p0 int, slo return 0, fmt.Errorf("%w: recovering sender from signature: %s", ErrParseTxn, err) //nolint } //apply keccak to the public key - ctx.Keccak2.Reset() - if _, err = ctx.Keccak2.Write(ctx.buf[1:65]); err != nil { + k2.Reset() + if _, err = k2.Write(ctx.buf[1:65]); err != nil { return 0, fmt.Errorf("%w: computing sender from public key: %s", ErrParseTxn, err) //nolint } // squeeze the hash of the public key //ctx.keccak2.Sum(ctx.buf[:0]) - _, _ = ctx.Keccak2.(io.Reader).Read(ctx.buf[:32]) + _, _ = k2.(io.Reader).Read(ctx.buf[:32]) //take last 20 bytes as address copy(sender, ctx.buf[12:32]) diff --git a/types/txn_decode.go b/types/txn_decode.go index 5a3279fe2..da699f996 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -10,19 +10,18 @@ import ( "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon-lib/common/u256" "github.com/ledgerwatch/erigon-lib/crypto" + "github.com/ledgerwatch/erigon-lib/crypto/cryptopool" "github.com/ledgerwatch/erigon-lib/rlp" "github.com/ledgerwatch/secp256k1" ) -func (ctx *TxParseContext) parseTransaction2(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (err error) { - if len(payload) == 0 { +func (ctx *TxParseContext) parseTransaction2(decoder *rlp.Decoder, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (err error) { + if decoder.Len() == 0 { return fmt.Errorf("empty rlp") } if ctx.withSender && len(sender) != 20 { return fmt.Errorf("expect sender buffer of len 20") } - decoder := rlp.NewDecoder(payload[pos:]) - dec, tok, err := decoder.ElemDec() if err != nil { return fmt.Errorf("size prefix: %w", err) //nolint @@ -158,9 +157,32 @@ func (ctx *TxParseContext) parseBlobs(dec *rlp.Decoder, slot *TxSlot) (err error func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, sender []byte, validateHash func([]byte) error) (err error) { legacy := slot.Type == LegacyTxType - // Compute transaction hash - ctx.Keccak1.Reset() - ctx.Keccak2.Reset() + k1 := cryptopool.GetLegacyKeccak256() + k2 := cryptopool.GetLegacyKeccak256() + defer cryptopool.ReturnLegacyKeccak256(k1) + defer cryptopool.ReturnLegacyKeccak256(k2) + + if !legacy { + typeByte := []byte{slot.Type} + var tok rlp.Token + if _, err := k1.Write(typeByte); err != nil { + return err + } + if _, err := k2.Write(typeByte); err != nil { + return err + } + dec, tok, err = dec.ElemDec() + if err != nil { + return err + } + if !tok.IsListType() { + return fmt.Errorf("expected list") + } + if _, err := k1.Write(dec.Bytes()); err != nil { + return fmt.Errorf("compute idHash: %w", err) + } + } + if ctx.validateRlp != nil { if err := ctx.validateRlp(slot.Rlp); err != nil { return err @@ -182,6 +204,7 @@ func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, return fmt.Errorf("%s, %d (expected %d)", "invalid chainID", ctx.ChainID.Uint64(), ctx.cfg.ChainID.Uint64()) } } + // Next follows the nonce, which we need to parse err = rlp.ReadElem(dec, rlp.Uint64, &slot.Nonce) if err != nil { @@ -324,17 +347,17 @@ func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, return fmt.Errorf("S: %s", err) //nolint } - if _, err = ctx.Keccak1.Write([]byte{slot.Type}); err != nil { - return fmt.Errorf("computing IdHash: %s", err) //nolint - } - // For legacy transactions, hash the full payload if legacy { - if _, err = ctx.Keccak1.Write(dec.Consumed()); err != nil { + if _, err = k1.Write(dec.Consumed()); err != nil { return fmt.Errorf("computing IdHash: %s", err) //nolint } } - _, _ = ctx.Keccak1.(io.Reader).Read(slot.IDHash[:32]) + + // write the hash to IdHash buffer + //k1.Sum(slot.IDHash[:0]) + _, _ = k1.(io.Reader).Read(slot.IDHash[:32]) + if validateHash != nil { if err := validateHash(slot.IDHash[:32]); err != nil { return err @@ -353,25 +376,25 @@ func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, // Write len Prefix to the Sighash if dec.Offset() < 56 { ctx.buf[0] = byte(dec.Offset()) + 192 - if _, err := ctx.Keccak2.Write(ctx.buf[:1]); err != nil { + if _, err := k2.Write(ctx.buf[:1]); err != nil { return fmt.Errorf("computing signHash (hashing len Prefix): %s", err) //nolint } } else { beLen := common.BitLenToByteLen(bits.Len(uint(dec.Offset()))) binary.BigEndian.PutUint64(ctx.buf[1:], uint64(dec.Offset())) ctx.buf[8-beLen] = byte(beLen) + 247 - if _, err := ctx.Keccak2.Write(ctx.buf[8-beLen : 9]); err != nil { + if _, err := k2.Write(ctx.buf[8-beLen : 9]); err != nil { return fmt.Errorf("computing signHash (hashing len Prefix): %s", err) //nolint } } - if _, err = ctx.Keccak2.Write(dec.Consumed()); err != nil { + if _, err = k2.Write(dec.Consumed()); err != nil { return fmt.Errorf("computing signHash: %s", err) //nolint } if legacy { if chainIDLen > 0 { if chainIDBits <= 7 { ctx.buf[0] = byte(ctx.ChainID.Uint64()) - if _, err := ctx.Keccak2.Write(ctx.buf[:1]); err != nil { + if _, err := k2.Write(ctx.buf[:1]); err != nil { return fmt.Errorf("computing signHash (hashing legacy chainId): %s", err) //nolint } } else { @@ -380,21 +403,21 @@ func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, binary.BigEndian.PutUint64(ctx.buf[17:25], ctx.ChainID[1]) binary.BigEndian.PutUint64(ctx.buf[25:33], ctx.ChainID[0]) ctx.buf[32-chainIDLen] = 128 + byte(chainIDLen) - if _, err = ctx.Keccak2.Write(ctx.buf[32-chainIDLen : 33]); err != nil { + if _, err = k2.Write(ctx.buf[32-chainIDLen : 33]); err != nil { return fmt.Errorf("computing signHash (hashing legacy chainId): %s", err) //nolint } } // Encode two zeros ctx.buf[0] = 128 ctx.buf[1] = 128 - if _, err := ctx.Keccak2.Write(ctx.buf[:2]); err != nil { + if _, err := k2.Write(ctx.buf[:2]); err != nil { return fmt.Errorf("computing signHash (hashing zeros after legacy chainId): %s", err) //nolint } } } // Squeeze Sighash - _, _ = ctx.Keccak2.(io.Reader).Read(ctx.Sighash[:32]) - //ctx.keccak2.Sum(ctx.Sighash[:0]) + _, _ = k2.(io.Reader).Read(ctx.Sighash[:32]) + //k2.Sum(ctx.Sighash[:0]) binary.BigEndian.PutUint64(ctx.Sig[0:8], ctx.R[3]) binary.BigEndian.PutUint64(ctx.Sig[8:16], ctx.R[2]) binary.BigEndian.PutUint64(ctx.Sig[16:24], ctx.R[1]) @@ -409,13 +432,13 @@ func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, return fmt.Errorf("recovering sender from signature: %s", err) //nolint } //apply keccak to the public key - ctx.Keccak2.Reset() - if _, err = ctx.Keccak2.Write(ctx.buf[1:65]); err != nil { + k2.Reset() + if _, err = k2.Write(ctx.buf[1:65]); err != nil { return fmt.Errorf("computing sender from public key: %s", err) //nolint } // squeeze the hash of the public key - //ctx.keccak2.Sum(ctx.buf[:0]) - _, _ = ctx.Keccak2.(io.Reader).Read(ctx.buf[:32]) + //k2.Sum(ctx.buf[:0]) + _, _ = k2.(io.Reader).Read(ctx.buf[:32]) //take last 20 bytes as address copy(sender, ctx.buf[12:32]) diff --git a/types/txn_packets.go b/types/txn_packets.go index 3f7ed58bd..17d0ba1c5 100644 --- a/types/txn_packets.go +++ b/types/txn_packets.go @@ -192,6 +192,28 @@ func ParseTransactions(payload []byte, pos int, ctx *TxParseContext, txSlots *Tx return pos, nil } +func DecodeTransactions(dec *rlp.Decoder, ctx *TxParseContext, txSlots *TxSlots, validateHash func([]byte) error) (err error) { + i := 0 + err = dec.ForList(func(d *rlp.Decoder) error { + txSlots.Resize(uint(i + 1)) + txSlots.Txs[i] = &TxSlot{} + err = ctx.parseTransaction2(d, txSlots.Txs[i], txSlots.Senders.At(i), true /* hasEnvelope */, true /* wrappedWithBlobs */, validateHash) + if err != nil { + if errors.Is(err, ErrRejected) { + txSlots.Resize(uint(i)) + return nil + } + return err + } + i = i + 1 + return nil + }) + if err != nil { + return fmt.Errorf("%w: %w", ErrParseTxn, err) + } + return nil +} + func ParsePooledTransactions66(payload []byte, pos int, ctx *TxParseContext, txSlots *TxSlots, validateHash func([]byte) error) (requestID uint64, newPos int, err error) { p, _, err := rlp.List(payload, pos) if err != nil { From 203d9c3c0ea76f3563022b439d8ef2b73ce83c85 Mon Sep 17 00:00:00 2001 From: a Date: Wed, 6 Sep 2023 23:40:14 -0500 Subject: [PATCH 19/34] wip --- rlp/parse.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rlp/parse.go b/rlp/parse.go index 449277f09..23ffc4637 100644 --- a/rlp/parse.go +++ b/rlp/parse.go @@ -37,7 +37,7 @@ func IsRLPError(err error) bool { return errors.Is(err, ErrBase) } // BeInt parses Big Endian representation of an integer from given payload at given position func BeInt(payload []byte, pos, length int) (int, error) { var r int - if pos+length > len(payload) { + if pos+length >= len(payload) { return 0, ErrUnexpectedEOF } if length > 0 && payload[pos] == 0 { From 0297db395584b3a07316f68985de41f0e3b8239d Mon Sep 17 00:00:00 2001 From: a Date: Wed, 6 Sep 2023 23:42:25 -0500 Subject: [PATCH 20/34] return p on error --- types/txn.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/types/txn.go b/types/txn.go index 4931d747a..db1324c16 100644 --- a/types/txn.go +++ b/types/txn.go @@ -148,12 +148,10 @@ func PeekTransactionType(serialized []byte) (byte, error) { // 0x03 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s] // 0x03 ||[[chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s], blobs, commitments, proofs] func (ctx *TxParseContext) ParseTransaction(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (p int, err error) { - //dec := rlp.NewDecoder(payload) p, err = ctx.parseTransaction(payload, pos, slot, sender, hasEnvelope, wrappedWithBlobs, validateHash) if err != nil { - return 0, fmt.Errorf("%w: %w", ErrParseTxn, err) + return p, fmt.Errorf("%w: %w", ErrParseTxn, err) } - //return dec.Offset() + pos, nil return p, nil } func (ctx *TxParseContext) parseTransaction(payload []byte, pos int, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (p int, err error) { From 8ecc3ff381cde21249caa5c3546b754a6839a455 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 02:31:22 -0500 Subject: [PATCH 21/34] aok --- rlp/decoder.go | 35 +++++++++-- rlp/parse.go | 2 +- txpool/fetch.go | 2 +- types/txn_decode.go | 117 +++++++++++++++++++----------------- types/txn_packets.go | 23 ------- types/txn_packets_decode.go | 30 +++++++++ types/txn_packets_test.go | 5 +- types/txn_test.go | 35 +++++------ 8 files changed, 142 insertions(+), 107 deletions(-) create mode 100644 types/txn_packets_decode.go diff --git a/rlp/decoder.go b/rlp/decoder.go index 8baf0d52a..912d51c5d 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -15,12 +15,16 @@ func NewDecoder(buf []byte) *Decoder { } } +func (d *Decoder) String() string { + return fmt.Sprintf(`left=%x pos=%d`, d.buf.Bytes(), d.buf.off) +} + func (d *Decoder) Consumed() []byte { return d.buf.u[:d.buf.off] } func (d *Decoder) Underlying() []byte { - return d.buf.u + return d.buf.Underlying() } func (d *Decoder) Len() int { @@ -39,6 +43,18 @@ func (d *Decoder) ReadByte() (n byte, err error) { return d.buf.ReadByte() } +func (d *Decoder) PeekByte() (n byte, err error) { + return d.buf.PeekByte() +} + +func (d *Decoder) PeekToken() (Token, error) { + prefix, err := d.PeekByte() + if err != nil { + return TokenUnknown, err + } + return identifyToken(prefix), nil +} + func (d *Decoder) ElemDec() (*Decoder, Token, error) { a, t, err := d.Elem() return NewDecoder(a), t, err @@ -86,6 +102,7 @@ func (d *Decoder) Elem() ([]byte, Token, error) { default: return nil, token, fmt.Errorf("%w: unknown token", ErrDecode) } + //log.Printf("%x %s\n", buf, token) if err != nil { return nil, token, err } @@ -118,13 +135,17 @@ func (d *Decoder) ForList(fn func(*Decoder) error) error { switch token { case TokenShortList, TokenLongList: dec := NewDecoder(buf) - for dec.buf.Len() > 0 { + for { + if dec.buf.Len() == 0 { + return nil + } err := fn(dec) if err != nil { return err } + // reset the byte + dec = NewDecoder(dec.Bytes()) } - return nil default: return fmt.Errorf("%w: ForList on non-list", ErrDecode) } @@ -169,7 +190,9 @@ func (b *buf) Offset() int { return b.off } -func (b *buf) Bytes() []byte { return b.u[b.off:] } +func (b *buf) Bytes() []byte { + return b.u[b.off:] +} func (b *buf) String() string { if b == nil { @@ -180,3 +203,7 @@ func (b *buf) String() string { } func (b *buf) Len() int { return len(b.u) - b.off } + +func (b *buf) Underlying() []byte { + return b.u +} diff --git a/rlp/parse.go b/rlp/parse.go index 23ffc4637..449277f09 100644 --- a/rlp/parse.go +++ b/rlp/parse.go @@ -37,7 +37,7 @@ func IsRLPError(err error) bool { return errors.Is(err, ErrBase) } // BeInt parses Big Endian representation of an integer from given payload at given position func BeInt(payload []byte, pos, length int) (int, error) { var r int - if pos+length >= len(payload) { + if pos+length > len(payload) { return 0, ErrUnexpectedEOF } if length > 0 && payload[pos] == 0 { diff --git a/txpool/fetch.go b/txpool/fetch.go index b319f638b..355916f54 100644 --- a/txpool/fetch.go +++ b/txpool/fetch.go @@ -326,7 +326,7 @@ func (f *Fetch) handleInboundMessage(ctx context.Context, req *sentry.InboundMes switch req.Id { case sentry.MessageId_TRANSACTIONS_66: if err := f.threadSafeParsePooledTxn(func(parseContext *types2.TxParseContext) error { - if _, err := types2.ParseTransactions(req.Data, 0, parseContext, &txs, func(hash []byte) error { + if err := types2.DecodeTransactions(rlp.NewDecoder(req.Data), parseContext, &txs, func(hash []byte) error { known, err := f.pool.IdHashKnown(tx, hash) if err != nil { return err diff --git a/types/txn_decode.go b/types/txn_decode.go index da699f996..d4412284c 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -15,35 +15,47 @@ import ( "github.com/ledgerwatch/secp256k1" ) -func (ctx *TxParseContext) parseTransaction2(decoder *rlp.Decoder, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (err error) { +func (ctx *TxParseContext) DecodeTransaction(decoder *rlp.Decoder, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (err error) { + err = ctx.decodeTransaction(decoder, slot, sender, hasEnvelope, wrappedWithBlobs, validateHash) + if err != nil { + return fmt.Errorf("%w: %w", ErrParseTxn, err) + } + return nil +} + +func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (err error) { if decoder.Len() == 0 { return fmt.Errorf("empty rlp") } if ctx.withSender && len(sender) != 20 { return fmt.Errorf("expect sender buffer of len 20") } - dec, tok, err := decoder.ElemDec() - if err != nil { - return fmt.Errorf("size prefix: %w", err) //nolint + tok, err := decoder.PeekToken() + dec := decoder + if tok == rlp.TokenDecimal { + if hasEnvelope { + return fmt.Errorf("expected envelope in the payload, got %s", tok) + } + } else if hasEnvelope && tok.IsBlobType() || tok.IsListType() { + dec, tok, err = decoder.ElemDec() + if err != nil { + return fmt.Errorf("size prefix: %w", err) //nolint + } } if dec.Len() == 0 { return fmt.Errorf("transaction must be either 1 list or 1 string") } - if dec.Len() == 1 && !tok.IsListType() { - if hasEnvelope { - return fmt.Errorf("expected envelope in the payload, got %x", dec.Bytes()) - } - } - - // Legacy transactions have list Prefix, whereas EIP-2718 transactions have string Prefix - // therefore we assign the first returned value of Prefix function (list) to legacy variable switch { - case tok.IsListType(): - slot.Rlp = append(make([]byte, 0, dec.Len()), dec.Bytes()...) + default: + return fmt.Errorf("expected list or blob token, got %s", tok) + case tok.IsListType(): // Legacy transactions have list Prefix, + slot.Rlp = append([]byte{}, decoder.Consumed()...) slot.Size = uint32(len(slot.Rlp)) slot.Type = LegacyTxType - case tok == rlp.TokenDecimal: + case tok == rlp.TokenDecimal: // EIP-2718 transactions have string Prefix + slot.Rlp = append([]byte{}, dec.Bytes()...) + slot.Size = uint32(len(slot.Rlp)) slot.Type, err = dec.ReadByte() if err != nil { return fmt.Errorf("couldnt read txn type: %w", err) @@ -51,20 +63,9 @@ func (ctx *TxParseContext) parseTransaction2(decoder *rlp.Decoder, slot *TxSlot, if slot.Type > BlobTxType { return fmt.Errorf("unknown transaction type: %d", slot.Type) } - dec, tok, err = dec.ElemDec() - if err != nil { - return err - } - if !tok.IsListType() { - return fmt.Errorf("expected list token") - } - slot.Rlp = append(make([]byte, 0, dec.Len()), dec.Bytes()...) - slot.Size = uint32(len(slot.Rlp)) - default: - return fmt.Errorf("expected list or decimal token") } - bodyDecoder := dec + bodyDecoder := rlp.NewDecoder(dec.Bytes()) // if its a blob transaction, we actually need to enter a nested list, since its [tx_payload_body, blobs, commitments, proofs] if slot.Type == BlobTxType && wrappedWithBlobs { bodyDecoder, _, err = dec.ElemDec() @@ -73,21 +74,21 @@ func (ctx *TxParseContext) parseTransaction2(decoder *rlp.Decoder, slot *TxSlot, } } - err = ctx.parseTransactionBody2(bodyDecoder, slot, sender, validateHash) + err = ctx.decodeTransactionBody(bodyDecoder, decoder, slot, sender, validateHash) if err != nil { - return err + return fmt.Errorf("txn body: %w", err) } - // so its a blob transaction and we need to do the extra stuff... + // so its a blob transaction and we need to do the extra stuff, otherwise we are done if slot.Type == BlobTxType && wrappedWithBlobs { - if err := ctx.parseBlobs(dec, slot); err != nil { - return err + if err := ctx.decodeBlobs(dec, slot); err != nil { + return fmt.Errorf("decode blobs: %w", err) } - if err := ctx.parseCommitments(dec, slot); err != nil { - return err + if err := ctx.decodeCommitments(dec, slot); err != nil { + return fmt.Errorf("decode commitments: %w", err) } - if err := ctx.parseProofs(dec, slot); err != nil { - return err + if err := ctx.decodeProofs(dec, slot); err != nil { + return fmt.Errorf("decode proofs: %w", err) } if len(slot.Blobs) != len(slot.Commitments) { return fmt.Errorf("blob count != commitment count") @@ -102,7 +103,7 @@ func (ctx *TxParseContext) parseTransaction2(decoder *rlp.Decoder, slot *TxSlot, return err } -func (ctx *TxParseContext) parseCommitments(dec *rlp.Decoder, slot *TxSlot) (err error) { +func (ctx *TxParseContext) decodeCommitments(dec *rlp.Decoder, slot *TxSlot) (err error) { err = dec.ForList(func(d *rlp.Decoder) error { var blob gokzg4844.KZGCommitment blobSlice := blob[:] @@ -120,7 +121,7 @@ func (ctx *TxParseContext) parseCommitments(dec *rlp.Decoder, slot *TxSlot) (err return nil } -func (ctx *TxParseContext) parseProofs(dec *rlp.Decoder, slot *TxSlot) (err error) { +func (ctx *TxParseContext) decodeProofs(dec *rlp.Decoder, slot *TxSlot) (err error) { err = dec.ForList(func(d *rlp.Decoder) error { var blob gokzg4844.KZGProof blobSlice := blob[:] @@ -137,7 +138,7 @@ func (ctx *TxParseContext) parseProofs(dec *rlp.Decoder, slot *TxSlot) (err erro return nil } -func (ctx *TxParseContext) parseBlobs(dec *rlp.Decoder, slot *TxSlot) (err error) { +func (ctx *TxParseContext) decodeBlobs(dec *rlp.Decoder, slot *TxSlot) (err error) { err = dec.ForList(func(d *rlp.Decoder) error { var blob []byte err := rlp.ReadElem(dec, rlp.Bytes, &blob) @@ -154,7 +155,7 @@ func (ctx *TxParseContext) parseBlobs(dec *rlp.Decoder, slot *TxSlot) (err error return nil } -func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, sender []byte, validateHash func([]byte) error) (err error) { +func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.Decoder, slot *TxSlot, sender []byte, validateHash func([]byte) error) (err error) { legacy := slot.Type == LegacyTxType k1 := cryptopool.GetLegacyKeccak256() @@ -162,23 +163,19 @@ func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, defer cryptopool.ReturnLegacyKeccak256(k1) defer cryptopool.ReturnLegacyKeccak256(k2) + k1.Reset() + k2.Reset() + + // for computing tx hash if !legacy { typeByte := []byte{slot.Type} - var tok rlp.Token if _, err := k1.Write(typeByte); err != nil { return err } if _, err := k2.Write(typeByte); err != nil { return err } - dec, tok, err = dec.ElemDec() - if err != nil { - return err - } - if !tok.IsListType() { - return fmt.Errorf("expected list") - } - if _, err := k1.Write(dec.Bytes()); err != nil { + if _, err := k1.Write(parent.Consumed()); err != nil { return fmt.Errorf("compute idHash: %w", err) } } @@ -189,6 +186,8 @@ func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, } } + // signing hash data starts here + sigHashPos := dec.Offset() if !legacy { err = rlp.ReadElem(dec, rlp.Uint256, &ctx.ChainID) if err != nil { @@ -292,6 +291,8 @@ func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, }) } // This is where the data for Sighash ends + sigHashEnd := dec.Offset() + sigHashLen := uint(sigHashEnd - sigHashPos) // Next follows V of the signature var vByte byte @@ -319,7 +320,10 @@ func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, chainIDLen = 1 } else { chainIDLen = common.BitLenToByteLen(chainIDBits) // It is always < 56 bytes + sigHashLen++ // For chainId len Prefix } + sigHashLen += uint(chainIDLen) // For chainId + sigHashLen += 2 // For two extra zeros ctx.DeriveChainID.Sub(&ctx.V, &ctx.ChainIDMul) vByte = byte(ctx.DeriveChainID.Sub(&ctx.DeriveChainID, u256.N8).Uint64() - 27) } @@ -347,9 +351,10 @@ func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, return fmt.Errorf("S: %s", err) //nolint } - // For legacy transactions, hash the full payload + // For legacy transactions, just hash the full payload if legacy { - if _, err = k1.Write(dec.Consumed()); err != nil { + u := parent.Consumed() + if _, err = k1.Write(u); err != nil { return fmt.Errorf("computing IdHash: %s", err) //nolint } } @@ -374,20 +379,20 @@ func (ctx *TxParseContext) parseTransactionBody2(dec *rlp.Decoder, slot *TxSlot, // Computing sigHash (hash used to recover sender from the signature) // Write len Prefix to the Sighash - if dec.Offset() < 56 { - ctx.buf[0] = byte(dec.Offset()) + 192 + if sigHashLen < 56 { + ctx.buf[0] = byte(sigHashLen) + 192 if _, err := k2.Write(ctx.buf[:1]); err != nil { return fmt.Errorf("computing signHash (hashing len Prefix): %s", err) //nolint } } else { - beLen := common.BitLenToByteLen(bits.Len(uint(dec.Offset()))) - binary.BigEndian.PutUint64(ctx.buf[1:], uint64(dec.Offset())) + beLen := common.BitLenToByteLen(bits.Len(uint(sigHashLen))) + binary.BigEndian.PutUint64(ctx.buf[1:], uint64(sigHashLen)) ctx.buf[8-beLen] = byte(beLen) + 247 if _, err := k2.Write(ctx.buf[8-beLen : 9]); err != nil { return fmt.Errorf("computing signHash (hashing len Prefix): %s", err) //nolint } } - if _, err = k2.Write(dec.Consumed()); err != nil { + if _, err = k2.Write(dec.Underlying()[sigHashPos:sigHashEnd]); err != nil { return fmt.Errorf("computing signHash: %s", err) //nolint } if legacy { diff --git a/types/txn_packets.go b/types/txn_packets.go index 17d0ba1c5..630341752 100644 --- a/types/txn_packets.go +++ b/types/txn_packets.go @@ -191,29 +191,6 @@ func ParseTransactions(payload []byte, pos int, ctx *TxParseContext, txSlots *Tx } return pos, nil } - -func DecodeTransactions(dec *rlp.Decoder, ctx *TxParseContext, txSlots *TxSlots, validateHash func([]byte) error) (err error) { - i := 0 - err = dec.ForList(func(d *rlp.Decoder) error { - txSlots.Resize(uint(i + 1)) - txSlots.Txs[i] = &TxSlot{} - err = ctx.parseTransaction2(d, txSlots.Txs[i], txSlots.Senders.At(i), true /* hasEnvelope */, true /* wrappedWithBlobs */, validateHash) - if err != nil { - if errors.Is(err, ErrRejected) { - txSlots.Resize(uint(i)) - return nil - } - return err - } - i = i + 1 - return nil - }) - if err != nil { - return fmt.Errorf("%w: %w", ErrParseTxn, err) - } - return nil -} - func ParsePooledTransactions66(payload []byte, pos int, ctx *TxParseContext, txSlots *TxSlots, validateHash func([]byte) error) (requestID uint64, newPos int, err error) { p, _, err := rlp.List(payload, pos) if err != nil { diff --git a/types/txn_packets_decode.go b/types/txn_packets_decode.go new file mode 100644 index 000000000..34721f825 --- /dev/null +++ b/types/txn_packets_decode.go @@ -0,0 +1,30 @@ +package types + +import ( + "errors" + "fmt" + + "github.com/ledgerwatch/erigon-lib/rlp" +) + +func DecodeTransactions(dec *rlp.Decoder, ctx *TxParseContext, txSlots *TxSlots, validateHash func([]byte) error) (err error) { + i := 0 + err = dec.ForList(func(d *rlp.Decoder) error { + txSlots.Resize(uint(i + 1)) + txSlots.Txs[i] = &TxSlot{} + err = ctx.DecodeTransaction(d, txSlots.Txs[i], txSlots.Senders.At(i), true /* hasEnvelope */, true /* wrappedWithBlobs */, validateHash) + if err != nil { + if errors.Is(err, ErrRejected) { + txSlots.Resize(uint(i)) + return nil + } + return fmt.Errorf("elem: %w", err) + } + i = i + 1 + return nil + }) + if err != nil { + return err + } + return nil +} diff --git a/types/txn_packets_test.go b/types/txn_packets_test.go index a21c523fd..9d23ac291 100644 --- a/types/txn_packets_test.go +++ b/types/txn_packets_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/require" "github.com/ledgerwatch/erigon-lib/common/hexutility" + "github.com/ledgerwatch/erigon-lib/rlp" ) var hashParseTests = []struct { @@ -205,7 +206,7 @@ func TestTransactionsPacket(t *testing.T) { ctx := NewTxParseContext(*uint256.NewInt(tt.chainID)) slots := &TxSlots{} - _, err := ParseTransactions(encodeBuf, 0, ctx, slots, nil) + err := DecodeTransactions(rlp.NewDecoder(encodeBuf), ctx, slots, nil) require.NoError(err) require.Equal(len(tt.txs), len(slots.Txs)) for i, txn := range tt.txs { @@ -223,7 +224,7 @@ func TestTransactionsPacket(t *testing.T) { chainID := uint256.NewInt(tt.chainID) ctx := NewTxParseContext(*chainID) slots := &TxSlots{} - _, err := ParseTransactions(encodeBuf, 0, ctx, slots, func(bytes []byte) error { return ErrRejected }) + err := DecodeTransactions(rlp.NewDecoder(encodeBuf), ctx, slots, func(bytes []byte) error { return ErrRejected }) require.NoError(err) require.Equal(0, len(slots.Txs)) require.Equal(0, slots.Senders.Len()) diff --git a/types/txn_test.go b/types/txn_test.go index b5e992d49..0fe0b15cf 100644 --- a/types/txn_test.go +++ b/types/txn_test.go @@ -17,7 +17,6 @@ package types import ( - "bytes" "crypto/rand" "strconv" "testing" @@ -29,6 +28,7 @@ import ( "github.com/ledgerwatch/erigon-lib/common/fixedgas" "github.com/ledgerwatch/erigon-lib/common/hexutility" + "github.com/ledgerwatch/erigon-lib/rlp" ) func TestParseTransactionRLP(t *testing.T) { @@ -42,26 +42,21 @@ func TestParseTransactionRLP(t *testing.T) { tt := tt t.Run(strconv.Itoa(i), func(t *testing.T) { payload := hexutility.MustDecodeHex(tt.PayloadStr) - parseEnd, err := ctx.ParseTransaction(payload, 0, tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) - require.NoError(err) - require.Equal(len(payload), parseEnd) + dec := rlp.NewDecoder(payload) + err := ctx.DecodeTransaction(dec, tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + require.NoError(err, tt.PayloadStr) + require.Equal(len(payload), dec.Offset()) if tt.SignHashStr != "" { signHash := hexutility.MustDecodeHex(tt.SignHashStr) - if !bytes.Equal(signHash, ctx.Sighash[:]) { - t.Errorf("signHash expected %x, got %x", signHash, ctx.Sighash) - } + assert.EqualValues(t, signHash, ctx.Sighash[:], "sighash") } if tt.IdHashStr != "" { idHash := hexutility.MustDecodeHex(tt.IdHashStr) - if !bytes.Equal(idHash, tx.IDHash[:]) { - t.Errorf("IdHash expected %x, got %x", idHash, tx.IDHash) - } + assert.EqualValues(t, idHash, tx.IDHash[:], "idhash") } if tt.SenderStr != "" { expectSender := hexutility.MustDecodeHex(tt.SenderStr) - if !bytes.Equal(expectSender, txSender[:]) { - t.Errorf("expectSender expected %x, got %x", expectSender, txSender) - } + assert.EqualValues(t, expectSender, txSender[:], "sender") } require.Equal(tt.Nonce, tx.Nonce) }) @@ -77,19 +72,19 @@ func TestTransactionSignatureValidity1(t *testing.T) { tx, txSender := &TxSlot{}, [20]byte{} validTxn := hexutility.MustDecodeHex("f83f800182520894095e7baea6a6c7c4c2dfeb977efac326af552d870b801ba048b55bfa915ac795c431978d8a6a992b628d557da5ff759b307d495a3664935301") - _, err := ctx.ParseTransaction(validTxn, 0, tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + err := ctx.DecodeTransaction(rlp.NewDecoder(validTxn), tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) assert.NoError(t, err) preEip2Txn := hexutility.MustDecodeHex("f85f800182520894095e7baea6a6c7c4c2dfeb977efac326af552d870b801ba048b55bfa915ac795c431978d8a6a992b628d557da5ff759b307d495a36649353a07fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a1") - _, err = ctx.ParseTransaction(preEip2Txn, 0, tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + err = ctx.DecodeTransaction(rlp.NewDecoder(validTxn), tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) assert.NoError(t, err) // Now enforce EIP-2 ctx.WithAllowPreEip2s(false) - _, err = ctx.ParseTransaction(validTxn, 0, tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + err = ctx.DecodeTransaction(rlp.NewDecoder(validTxn), tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) assert.NoError(t, err) - _, err = ctx.ParseTransaction(preEip2Txn, 0, tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + err = ctx.DecodeTransaction(rlp.NewDecoder(preEip2Txn), tx, txSender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) assert.Error(t, err) } @@ -98,13 +93,13 @@ func TestTransactionSignatureValidity2(t *testing.T) { chainId := new(uint256.Int).SetUint64(5) ctx := NewTxParseContext(*chainId) slot, sender := &TxSlot{}, [20]byte{} - rlp := hexutility.MustDecodeHex("02f8720513844190ab00848321560082520894cab441d2f45a3fee83d15c6b6b6c36a139f55b6288054607fc96a6000080c001a0dffe4cb5651e663d0eac8c4d002de734dd24db0f1109b062d17da290a133cc02a0913fb9f53f7a792bcd9e4d7cced1b8545d1ab82c77432b0bc2e9384ba6c250c5") - _, err := ctx.ParseTransaction(rlp, 0, slot, sender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + xs := hexutility.MustDecodeHex("02f8720513844190ab00848321560082520894cab441d2f45a3fee83d15c6b6b6c36a139f55b6288054607fc96a6000080c001a0dffe4cb5651e663d0eac8c4d002de734dd24db0f1109b062d17da290a133cc02a0913fb9f53f7a792bcd9e4d7cced1b8545d1ab82c77432b0bc2e9384ba6c250c5") + err := ctx.DecodeTransaction(rlp.NewDecoder(xs), slot, sender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) assert.Error(t, err) // Only legacy transactions can happen before EIP-2 ctx.WithAllowPreEip2s(true) - _, err = ctx.ParseTransaction(rlp, 0, slot, sender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + err = ctx.DecodeTransaction(rlp.NewDecoder(xs), slot, sender[:], false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) assert.Error(t, err) } From bf3229a667c1995e7489f93957132a9c88eea0c2 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 03:57:26 -0500 Subject: [PATCH 22/34] asdf --- rlp/decoder.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rlp/decoder.go b/rlp/decoder.go index 912d51c5d..e8cb965f6 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -27,8 +27,8 @@ func (d *Decoder) Underlying() []byte { return d.buf.Underlying() } -func (d *Decoder) Len() int { - return d.buf.Len() +func (d *Decoder) Empty() bool { + return d.buf.empty() } func (d *Decoder) Offset() int { From 33be682df4b250dc5039f1c684905fd9af43074d Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 03:58:32 -0500 Subject: [PATCH 23/34] head empty --- types/txn_decode.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/types/txn_decode.go b/types/txn_decode.go index d4412284c..27cca6277 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -24,7 +24,7 @@ func (ctx *TxParseContext) DecodeTransaction(decoder *rlp.Decoder, slot *TxSlot, } func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (err error) { - if decoder.Len() == 0 { + if decoder.Empty() { return fmt.Errorf("empty rlp") } if ctx.withSender && len(sender) != 20 { @@ -43,7 +43,7 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, } } - if dec.Len() == 0 { + if dec.Empty() { return fmt.Errorf("transaction must be either 1 list or 1 string") } switch { From 721b71b8438cb906534a3d4115c9f006f7786108 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 04:55:33 -0500 Subject: [PATCH 24/34] wip --- rlp/decoder.go | 6 +++++ types/txn_decode.go | 58 +++++++++++++++++++++++++++++++-------------- types/txn_test.go | 10 ++++---- 3 files changed, 52 insertions(+), 22 deletions(-) diff --git a/rlp/decoder.go b/rlp/decoder.go index e8cb965f6..36e90a923 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -47,6 +47,12 @@ func (d *Decoder) PeekByte() (n byte, err error) { return d.buf.PeekByte() } +func (d *Decoder) Fork() *Decoder { + return &Decoder{ + buf: newBuf(d.Bytes(), 0), + } +} + func (d *Decoder) PeekToken() (Token, error) { prefix, err := d.PeekByte() if err != nil { diff --git a/types/txn_decode.go b/types/txn_decode.go index 27cca6277..02aa4b663 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -15,6 +15,15 @@ import ( "github.com/ledgerwatch/secp256k1" ) +// DecodeTransaction extracts all the information from the transactions's payload (RLP) necessary to build TxSlot. +// It also performs syntactic validation of the transactions. +// wrappedWithBlobs means that for blob (type 3) transactions the full version with blobs/commitments/proofs is expected +// (see https://eips.ethereum.org/EIPS/eip-4844#networking). +// [ ] || [ nonce, price, limit, to, value, data, y,r,s] +// 0x01 || [chain_id, nonce, price, limit, to, value, data, access_list, y,r,s] +// 0x02 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, y,r,s] +// 0x03 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s] +// 0x03 ||[[chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s], blobs, commitments, proofs] func (ctx *TxParseContext) DecodeTransaction(decoder *rlp.Decoder, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (err error) { err = ctx.decodeTransaction(decoder, slot, sender, hasEnvelope, wrappedWithBlobs, validateHash) if err != nil { @@ -30,32 +39,40 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, if ctx.withSender && len(sender) != 20 { return fmt.Errorf("expect sender buffer of len 20") } - tok, err := decoder.PeekToken() - dec := decoder - if tok == rlp.TokenDecimal { + peektok, err := decoder.PeekToken() + // means that this is non-enveloped non-legacy transaction + if peektok == rlp.TokenDecimal { if hasEnvelope { - return fmt.Errorf("expected envelope in the payload, got %s", tok) + return fmt.Errorf("expected envelope in the payload, got %s", peektok) } - } else if hasEnvelope && tok.IsBlobType() || tok.IsListType() { - dec, tok, err = decoder.ElemDec() + } + + if decoder.Empty() { + return fmt.Errorf("transaction must be either 1 list or 1 string") + } + + // if is blob type, it means that its an envelope, so we need to get out of that + if peektok.IsBlobType() { + decoder, _, err = decoder.ElemDec() if err != nil { return fmt.Errorf("size prefix: %w", err) //nolint } } + dec := decoder - if dec.Empty() { - return fmt.Errorf("transaction must be either 1 list or 1 string") - } switch { default: - return fmt.Errorf("expected list or blob token, got %s", tok) - case tok.IsListType(): // Legacy transactions have list Prefix, + return fmt.Errorf("expected list or blob token, got %s", peektok) + case peektok.IsListType(): // Legacy transactions have list Prefix, + dec, peektok, err = decoder.ElemDec() + if err != nil { + return fmt.Errorf("size prefix: %w", err) //nolint + } slot.Rlp = append([]byte{}, decoder.Consumed()...) - slot.Size = uint32(len(slot.Rlp)) slot.Type = LegacyTxType - case tok == rlp.TokenDecimal: // EIP-2718 transactions have string Prefix - slot.Rlp = append([]byte{}, dec.Bytes()...) - slot.Size = uint32(len(slot.Rlp)) + case peektok.IsBlobType() || peektok == rlp.TokenDecimal: // EIP-2718 transactions have string Prefix + slot.Rlp = append([]byte{}, decoder.Bytes()...) + // at this point, the next byte is the type slot.Type, err = dec.ReadByte() if err != nil { return fmt.Errorf("couldnt read txn type: %w", err) @@ -63,9 +80,16 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, if slot.Type > BlobTxType { return fmt.Errorf("unknown transaction type: %d", slot.Type) } + // now enter the list, since that is what we are in front of now + dec, _, err = dec.ElemDec() + if err != nil { + return fmt.Errorf("extract txn body: %w", err) //nolint + } + } bodyDecoder := rlp.NewDecoder(dec.Bytes()) + // if its a blob transaction, we actually need to enter a nested list, since its [tx_payload_body, blobs, commitments, proofs] if slot.Type == BlobTxType && wrappedWithBlobs { bodyDecoder, _, err = dec.ElemDec() @@ -100,6 +124,7 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, return fmt.Errorf("blob count != blob hash count") } } + slot.Size = uint32(decoder.Offset()) return err } @@ -169,9 +194,6 @@ func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.D // for computing tx hash if !legacy { typeByte := []byte{slot.Type} - if _, err := k1.Write(typeByte); err != nil { - return err - } if _, err := k2.Write(typeByte); err != nil { return err } diff --git a/types/txn_test.go b/types/txn_test.go index 0fe0b15cf..311c27474 100644 --- a/types/txn_test.go +++ b/types/txn_test.go @@ -200,9 +200,10 @@ func TestBlobTxParsing(t *testing.T) { require.NoError(t, err) assert.Equal(t, BlobTxType, txType) - p, err := ctx.ParseTransaction(bodyEnvelope, 0, &thinTx, nil, hasEnvelope, wrappedWithBlobs, nil) + dec := rlp.NewDecoder(bodyEnvelope) + err = ctx.DecodeTransaction(dec, &thinTx, nil, hasEnvelope, wrappedWithBlobs, nil) require.NoError(t, err) - assert.Equal(t, len(bodyEnvelope), p) + assert.Equal(t, len(bodyEnvelope), dec.Offset()) assert.Equal(t, len(bodyEnvelope), int(thinTx.Size)) assert.Equal(t, bodyEnvelope[3:], thinTx.Rlp) assert.Equal(t, BlobTxType, thinTx.Type) @@ -253,9 +254,10 @@ func TestBlobTxParsing(t *testing.T) { require.NoError(t, err) assert.Equal(t, BlobTxType, txType) - p, err = ctx.ParseTransaction(wrapperRlp, 0, &fatTx, nil, hasEnvelope, wrappedWithBlobs, nil) + dec = rlp.NewDecoder(wrapperRlp) + err = ctx.DecodeTransaction(dec, &fatTx, nil, hasEnvelope, wrappedWithBlobs, nil) require.NoError(t, err) - assert.Equal(t, len(wrapperRlp), p) + assert.Equal(t, len(wrapperRlp), dec.Offset()) assert.Equal(t, len(wrapperRlp), int(fatTx.Size)) assert.Equal(t, wrapperRlp, fatTx.Rlp) assert.Equal(t, BlobTxType, fatTx.Type) From 13c0fdc4cb2a3283cb84ee028c872aa10532de85 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 05:01:10 -0500 Subject: [PATCH 25/34] anotha one --- types/txn_decode.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/types/txn_decode.go b/types/txn_decode.go index 02aa4b663..cf94b78ca 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -33,6 +33,7 @@ func (ctx *TxParseContext) DecodeTransaction(decoder *rlp.Decoder, slot *TxSlot, } func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (err error) { + od := decoder if decoder.Empty() { return fmt.Errorf("empty rlp") } @@ -124,7 +125,7 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, return fmt.Errorf("blob count != blob hash count") } } - slot.Size = uint32(decoder.Offset()) + slot.Size = uint32(od.Offset()) return err } From f4d2a1ccc3f79e7a813f484bddae0855389f1076 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 05:18:27 -0500 Subject: [PATCH 26/34] typo --- types/txn_decode.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/types/txn_decode.go b/types/txn_decode.go index cf94b78ca..1033a3c1b 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -89,8 +89,7 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, } - bodyDecoder := rlp.NewDecoder(dec.Bytes()) - + bodyDecoder := dec // if its a blob transaction, we actually need to enter a nested list, since its [tx_payload_body, blobs, commitments, proofs] if slot.Type == BlobTxType && wrappedWithBlobs { bodyDecoder, _, err = dec.ElemDec() @@ -133,7 +132,7 @@ func (ctx *TxParseContext) decodeCommitments(dec *rlp.Decoder, slot *TxSlot) (er err = dec.ForList(func(d *rlp.Decoder) error { var blob gokzg4844.KZGCommitment blobSlice := blob[:] - err := rlp.ReadElem(dec, rlp.BytesExact, &blobSlice) + err := rlp.ReadElem(d, rlp.BytesExact, &blobSlice) if err != nil { return err } @@ -151,7 +150,7 @@ func (ctx *TxParseContext) decodeProofs(dec *rlp.Decoder, slot *TxSlot) (err err err = dec.ForList(func(d *rlp.Decoder) error { var blob gokzg4844.KZGProof blobSlice := blob[:] - err := rlp.ReadElem(dec, rlp.BytesExact, &blobSlice) + err := rlp.ReadElem(d, rlp.BytesExact, &blobSlice) if err != nil { return err } @@ -167,7 +166,7 @@ func (ctx *TxParseContext) decodeProofs(dec *rlp.Decoder, slot *TxSlot) (err err func (ctx *TxParseContext) decodeBlobs(dec *rlp.Decoder, slot *TxSlot) (err error) { err = dec.ForList(func(d *rlp.Decoder) error { var blob []byte - err := rlp.ReadElem(dec, rlp.Bytes, &blob) + err := rlp.ReadElem(d, rlp.Bytes, &blob) if err != nil { return err } From a22b3d32615f23c162267f198dca86e0da62bae5 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 07:47:38 -0500 Subject: [PATCH 27/34] yay --- rlp/decoder.go | 55 +++++++++++- types/txn_decode.go | 207 +++++++++++++++++++++++++------------------- types/txn_test.go | 2 +- 3 files changed, 173 insertions(+), 91 deletions(-) diff --git a/rlp/decoder.go b/rlp/decoder.go index 36e90a923..d946f07ca 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -47,9 +47,13 @@ func (d *Decoder) PeekByte() (n byte, err error) { return d.buf.PeekByte() } +func (d *Decoder) Rebase() { + d.buf.u = d.Bytes() + d.buf.off = 0 +} func (d *Decoder) Fork() *Decoder { return &Decoder{ - buf: newBuf(d.Bytes(), 0), + buf: newBuf(d.buf.u, d.buf.off), } } @@ -66,6 +70,55 @@ func (d *Decoder) ElemDec() (*Decoder, Token, error) { return NewDecoder(a), t, err } +func (d *Decoder) RawElem() ([]byte, Token, error) { + w := d.buf + start := w.Offset() + // figure out what we are reading + prefix, err := w.ReadByte() + if err != nil { + return nil, TokenUnknown, err + } + token := identifyToken(prefix) + + var ( + sz int + lenSz int + ) + // switch on the token + switch token { + case TokenDecimal: + // in this case, the value is just the byte itself + case TokenShortList: + sz = int(token.Diff(prefix)) + _, err = nextFull(w, sz) + case TokenLongList: + lenSz = int(token.Diff(prefix)) + sz, err = nextBeInt(w, lenSz) + if err != nil { + return nil, token, err + } + _, err = nextFull(w, sz) + case TokenShortBlob: + sz := int(token.Diff(prefix)) + _, err = nextFull(w, sz) + case TokenLongBlob: + lenSz := int(token.Diff(prefix)) + sz, err := nextBeInt(w, lenSz) + if err != nil { + return nil, token, err + } + _, err = nextFull(w, sz) + default: + return nil, token, fmt.Errorf("%w: unknown token", ErrDecode) + } + stop := w.Offset() + //log.Printf("%x %s\n", buf, token) + if err != nil { + return nil, token, err + } + return w.Underlying()[start:stop], token, nil +} + func (d *Decoder) Elem() ([]byte, Token, error) { w := d.buf // figure out what we are reading diff --git a/types/txn_decode.go b/types/txn_decode.go index 1033a3c1b..c163f5f4f 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "io" + "log" "math/bits" gokzg4844 "github.com/crate-crypto/go-kzg-4844" @@ -19,11 +20,12 @@ import ( // It also performs syntactic validation of the transactions. // wrappedWithBlobs means that for blob (type 3) transactions the full version with blobs/commitments/proofs is expected // (see https://eips.ethereum.org/EIPS/eip-4844#networking). -// [ ] || [ nonce, price, limit, to, value, data, y,r,s] -// 0x01 || [chain_id, nonce, price, limit, to, value, data, access_list, y,r,s] -// 0x02 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, y,r,s] -// 0x03 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s] -// 0x03 ||[[chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s], blobs, commitments, proofs] +// +// // [ ] || [ nonce, price, limit, to, value, data, y,r,s] +// // 0x01 || [chain_id, nonce, price, limit, to, value, data, access_list, y,r,s] +// // 0x02 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, y,r,s] +// // 0x03 || [chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s] +// // 0x03 ||[[chain_id, nonce, tip, price, limit, to, value, data, access_list, blob_price, blob_hash, y,r,s], blobs, commitments, proofs] func (ctx *TxParseContext) DecodeTransaction(decoder *rlp.Decoder, slot *TxSlot, sender []byte, hasEnvelope, wrappedWithBlobs bool, validateHash func([]byte) error) (err error) { err = ctx.decodeTransaction(decoder, slot, sender, hasEnvelope, wrappedWithBlobs, validateHash) if err != nil { @@ -40,6 +42,7 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, if ctx.withSender && len(sender) != 20 { return fmt.Errorf("expect sender buffer of len 20") } + // start classification peektok, err := decoder.PeekToken() // means that this is non-enveloped non-legacy transaction if peektok == rlp.TokenDecimal { @@ -48,57 +51,82 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, } } - if decoder.Empty() { - return fmt.Errorf("transaction must be either 1 list or 1 string") - } + var ( + dec *rlp.Decoder + ) - // if is blob type, it means that its an envelope, so we need to get out of that - if peektok.IsBlobType() { - decoder, _, err = decoder.ElemDec() - if err != nil { - return fmt.Errorf("size prefix: %w", err) //nolint - } - } - dec := decoder + var ( + parent *rlp.Decoder + bodyDecoder *rlp.Decoder + ) + var tok rlp.Token switch { default: return fmt.Errorf("expected list or blob token, got %s", peektok) case peektok.IsListType(): // Legacy transactions have list Prefix, - dec, peektok, err = decoder.ElemDec() + // enter the list + parent = decoder + bodyDecoder, _, err = decoder.ElemDec() if err != nil { return fmt.Errorf("size prefix: %w", err) //nolint } slot.Rlp = append([]byte{}, decoder.Consumed()...) slot.Type = LegacyTxType case peektok.IsBlobType() || peektok == rlp.TokenDecimal: // EIP-2718 transactions have string Prefix + // if is blob type, it means that its an envelope, so we need to get out of that + if peektok.IsBlobType() { + decoder, _, err = decoder.ElemDec() + if err != nil { + return fmt.Errorf("size prefix: %w", err) //nolint + } + if decoder.Empty() { + return fmt.Errorf("transaction must be either 1 list or 1 string") + } + } slot.Rlp = append([]byte{}, decoder.Bytes()...) // at this point, the next byte is the type - slot.Type, err = dec.ReadByte() + slot.Type, err = decoder.ReadByte() if err != nil { return fmt.Errorf("couldnt read txn type: %w", err) } if slot.Type > BlobTxType { return fmt.Errorf("unknown transaction type: %d", slot.Type) } - // now enter the list, since that is what we are in front of now - dec, _, err = dec.ElemDec() + // from here to the end of the element, if this is not a blob tx type with blobs, is the parent + parent = decoder.Fork() + parent.Rebase() + // now enter the list, since that is what we are in front of now. + dec, _, err = decoder.ElemDec() if err != nil { return fmt.Errorf("extract txn body: %w", err) //nolint } - - } - - bodyDecoder := dec - // if its a blob transaction, we actually need to enter a nested list, since its [tx_payload_body, blobs, commitments, proofs] - if slot.Type == BlobTxType && wrappedWithBlobs { - bodyDecoder, _, err = dec.ElemDec() - if err != nil { - return fmt.Errorf("wrapped blob tx: %w", err) //nolint + bodyDecoder = dec + if slot.Type == BlobTxType { + if wrappedWithBlobs { + // if its a blob transaction and wrapped with blobs, we actually need to enter a nested list + // in this case, "decoder" was an iterator for the array of [ [txbody...], blobs, commitments, proofs] + // so dec is now pointing at the head of the first element [txbody...] + tmp := dec.Fork() + tmp.Rebase() + parentBytes, _, err := tmp.RawElem() + if err != nil { + return fmt.Errorf("wrapped blob tx body: %w", err) //nolint + } + parent = rlp.NewDecoder(parentBytes) + bodyDecoder, _, err = dec.ElemDec() + if err != nil { + return fmt.Errorf("wrapped blob tx body: %w", err) //nolint + } + } else { + // otherwise its not wrapped with blobs, so we do nothing special + } } } - - err = ctx.decodeTransactionBody(bodyDecoder, decoder, slot, sender, validateHash) + log.Println("praent", parent) + log.Println("bodydec", bodyDecoder, tok) + log.Println("also", wrappedWithBlobs) + err = ctx.decodeTransactionBody(bodyDecoder, parent, slot, sender, validateHash) if err != nil { return fmt.Errorf("txn body: %w", err) } @@ -128,58 +156,6 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, return err } -func (ctx *TxParseContext) decodeCommitments(dec *rlp.Decoder, slot *TxSlot) (err error) { - err = dec.ForList(func(d *rlp.Decoder) error { - var blob gokzg4844.KZGCommitment - blobSlice := blob[:] - err := rlp.ReadElem(d, rlp.BytesExact, &blobSlice) - if err != nil { - return err - } - slot.Commitments = append(slot.Commitments, blob) - return nil - }) - - if err != nil { - return err - } - return nil -} - -func (ctx *TxParseContext) decodeProofs(dec *rlp.Decoder, slot *TxSlot) (err error) { - err = dec.ForList(func(d *rlp.Decoder) error { - var blob gokzg4844.KZGProof - blobSlice := blob[:] - err := rlp.ReadElem(d, rlp.BytesExact, &blobSlice) - if err != nil { - return err - } - slot.Proofs = append(slot.Proofs, blob) - return nil - }) - if err != nil { - return err - } - return nil -} - -func (ctx *TxParseContext) decodeBlobs(dec *rlp.Decoder, slot *TxSlot) (err error) { - err = dec.ForList(func(d *rlp.Decoder) error { - var blob []byte - err := rlp.ReadElem(d, rlp.Bytes, &blob) - if err != nil { - return err - } - slot.Blobs = append(slot.Blobs, blob) - return nil - }) - if err != nil { - return err - } - - return nil -} - func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.Decoder, slot *TxSlot, sender []byte, validateHash func([]byte) error) (err error) { legacy := slot.Type == LegacyTxType @@ -194,10 +170,13 @@ func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.D // for computing tx hash if !legacy { typeByte := []byte{slot.Type} + if _, err := k1.Write(typeByte); err != nil { + return err + } if _, err := k2.Write(typeByte); err != nil { return err } - if _, err := k1.Write(parent.Consumed()); err != nil { + if _, err := k1.Write(parent.Underlying()); err != nil { return fmt.Errorf("compute idHash: %w", err) } } @@ -301,10 +280,10 @@ func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.D if err != nil { return fmt.Errorf("blob fee cap: %s", err) //nolint } - dec.ForList(func(dec *rlp.Decoder) error { + dec.ForList(func(d *rlp.Decoder) error { var blob common.Hash blobSlice := blob[:] - err := rlp.ReadElem(dec, rlp.BytesExact, &blobSlice) + err := rlp.ReadElem(d, rlp.BytesExact, &blobSlice) if err != nil { return err } @@ -372,11 +351,9 @@ func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.D if err != nil { return fmt.Errorf("S: %s", err) //nolint } - // For legacy transactions, just hash the full payload if legacy { - u := parent.Consumed() - if _, err = k1.Write(u); err != nil { + if _, err = k1.Write(parent.Consumed()); err != nil { return fmt.Errorf("computing IdHash: %s", err) //nolint } } @@ -471,3 +448,55 @@ func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.D return nil } + +func (ctx *TxParseContext) decodeCommitments(dec *rlp.Decoder, slot *TxSlot) (err error) { + err = dec.ForList(func(d *rlp.Decoder) error { + var blob gokzg4844.KZGCommitment + blobSlice := blob[:] + err := rlp.ReadElem(d, rlp.BytesExact, &blobSlice) + if err != nil { + return err + } + slot.Commitments = append(slot.Commitments, blob) + return nil + }) + + if err != nil { + return err + } + return nil +} + +func (ctx *TxParseContext) decodeProofs(dec *rlp.Decoder, slot *TxSlot) (err error) { + err = dec.ForList(func(d *rlp.Decoder) error { + var blob gokzg4844.KZGProof + blobSlice := blob[:] + err := rlp.ReadElem(d, rlp.BytesExact, &blobSlice) + if err != nil { + return err + } + slot.Proofs = append(slot.Proofs, blob) + return nil + }) + if err != nil { + return err + } + return nil +} + +func (ctx *TxParseContext) decodeBlobs(dec *rlp.Decoder, slot *TxSlot) (err error) { + err = dec.ForList(func(d *rlp.Decoder) error { + var blob []byte + err := rlp.ReadElem(d, rlp.Bytes, &blob) + if err != nil { + return err + } + slot.Blobs = append(slot.Blobs, blob) + return nil + }) + if err != nil { + return err + } + + return nil +} diff --git a/types/txn_test.go b/types/txn_test.go index 311c27474..7cc616488 100644 --- a/types/txn_test.go +++ b/types/txn_test.go @@ -271,7 +271,7 @@ func TestBlobTxParsing(t *testing.T) { assert.Equal(t, thinTx.AlAddrCount, fatTx.AlAddrCount) assert.Equal(t, thinTx.AlStorCount, fatTx.AlStorCount) assert.Equal(t, thinTx.Gas, fatTx.Gas) - assert.Equal(t, thinTx.IDHash, fatTx.IDHash) + assert.Equal(t, thinTx.IDHash, fatTx.IDHash, "idhash") assert.Equal(t, thinTx.Creation, fatTx.Creation) assert.Equal(t, thinTx.BlobFeeCap, fatTx.BlobFeeCap) assert.Equal(t, thinTx.BlobHashes, fatTx.BlobHashes) From 6f7e90e26509d527a2c12bd6720a2a41ac6f41f7 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 07:54:50 -0500 Subject: [PATCH 28/34] ok --- rlp/decoder.go | 5 +++++ types/txn_decode.go | 30 +++++++++++++----------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/rlp/decoder.go b/rlp/decoder.go index d946f07ca..56a106856 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -70,6 +70,11 @@ func (d *Decoder) ElemDec() (*Decoder, Token, error) { return NewDecoder(a), t, err } +func (d *Decoder) RawElemDec() (*Decoder, Token, error) { + a, t, err := d.RawElem() + return NewDecoder(a), t, err +} + func (d *Decoder) RawElem() ([]byte, Token, error) { w := d.buf start := w.Offset() diff --git a/types/txn_decode.go b/types/txn_decode.go index c163f5f4f..d6ed0823c 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "fmt" "io" - "log" "math/bits" gokzg4844 "github.com/crate-crypto/go-kzg-4844" @@ -56,10 +55,9 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, ) var ( - parent *rlp.Decoder - bodyDecoder *rlp.Decoder + parent *rlp.Decoder // the parent should contain in its underlying buffer the rlp used for txn hash creation sans txn type + bodyDecoder *rlp.Decoder // the bodyDecoder should be an rlp decoder primed at the top of the list body for txn ) - var tok rlp.Token switch { default: @@ -94,27 +92,28 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, return fmt.Errorf("unknown transaction type: %d", slot.Type) } // from here to the end of the element, if this is not a blob tx type with blobs, is the parent - parent = decoder.Fork() - parent.Rebase() + parent, _, err = decoder.RawElemDec() + if err != nil { + return fmt.Errorf("extract txn body: %w", err) //nolint + } // now enter the list, since that is what we are in front of now. - dec, _, err = decoder.ElemDec() + bodyDecoder, _, err = parent.ElemDec() if err != nil { return fmt.Errorf("extract txn body: %w", err) //nolint } - bodyDecoder = dec if slot.Type == BlobTxType { if wrappedWithBlobs { + dec = bodyDecoder // if its a blob transaction and wrapped with blobs, we actually need to enter a nested list // in this case, "decoder" was an iterator for the array of [ [txbody...], blobs, commitments, proofs] - // so dec is now pointing at the head of the first element [txbody...] - tmp := dec.Fork() - tmp.Rebase() - parentBytes, _, err := tmp.RawElem() + // so our bodyDecoder is actually a pointer to the header of [txbody...]. + // we can extract the raw elem out of this in order to get the parent decoder + parent, _, err = bodyDecoder.RawElemDec() if err != nil { return fmt.Errorf("wrapped blob tx body: %w", err) //nolint } - parent = rlp.NewDecoder(parentBytes) - bodyDecoder, _, err = dec.ElemDec() + // and then the body is actually just the parent decoder read once, since we are sitting at the top of the [txbody...] header + bodyDecoder, _, err = parent.ElemDec() if err != nil { return fmt.Errorf("wrapped blob tx body: %w", err) //nolint } @@ -123,9 +122,6 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, } } } - log.Println("praent", parent) - log.Println("bodydec", bodyDecoder, tok) - log.Println("also", wrappedWithBlobs) err = ctx.decodeTransactionBody(bodyDecoder, parent, slot, sender, validateHash) if err != nil { return fmt.Errorf("txn body: %w", err) From 1cfd0b889189ef6d504cb0e58aedd74c78297988 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 07:55:26 -0500 Subject: [PATCH 29/34] peek --- types/txn_decode.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/types/txn_decode.go b/types/txn_decode.go index d6ed0823c..f5b29df49 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -42,11 +42,11 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, return fmt.Errorf("expect sender buffer of len 20") } // start classification - peektok, err := decoder.PeekToken() + token, err := decoder.PeekToken() // means that this is non-enveloped non-legacy transaction - if peektok == rlp.TokenDecimal { + if token == rlp.TokenDecimal { if hasEnvelope { - return fmt.Errorf("expected envelope in the payload, got %s", peektok) + return fmt.Errorf("expected envelope in the payload, got %s", token) } } @@ -61,8 +61,8 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, switch { default: - return fmt.Errorf("expected list or blob token, got %s", peektok) - case peektok.IsListType(): // Legacy transactions have list Prefix, + return fmt.Errorf("expected list or blob token, got %s", token) + case token.IsListType(): // Legacy transactions have list Prefix, // enter the list parent = decoder bodyDecoder, _, err = decoder.ElemDec() @@ -71,9 +71,9 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, } slot.Rlp = append([]byte{}, decoder.Consumed()...) slot.Type = LegacyTxType - case peektok.IsBlobType() || peektok == rlp.TokenDecimal: // EIP-2718 transactions have string Prefix + case token.IsBlobType() || token == rlp.TokenDecimal: // EIP-2718 transactions have string Prefix // if is blob type, it means that its an envelope, so we need to get out of that - if peektok.IsBlobType() { + if token.IsBlobType() { decoder, _, err = decoder.ElemDec() if err != nil { return fmt.Errorf("size prefix: %w", err) //nolint From a382cc4067178ea64a8a7ddb588c6f7d5b0af74e Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 07:57:30 -0500 Subject: [PATCH 30/34] fizz --- types/txn_types_fuzz_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/types/txn_types_fuzz_test.go b/types/txn_types_fuzz_test.go index 1f43a695d..110498c19 100644 --- a/types/txn_types_fuzz_test.go +++ b/types/txn_types_fuzz_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/ledgerwatch/erigon-lib/common/u256" + "github.com/ledgerwatch/erigon-lib/rlp" ) // golang.org/s/draft-fuzzing-design @@ -27,6 +28,6 @@ func FuzzParseTx(f *testing.F) { ctx := NewTxParseContext(*u256.N1) txn := &TxSlot{} sender := make([]byte, 20) - _, _ = ctx.ParseTransaction(in, pos, txn, sender, false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) + ctx.DecodeTransaction(rlp.NewDecoder(in), txn, sender, false /* hasEnvelope */, true /* wrappedWithBlobs */, nil) }) } From e11a96460d0e840f746d1651369eda190655210e Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 08:00:32 -0500 Subject: [PATCH 31/34] lint --- rlp/decoder.go | 6 ++++++ rlp/util.go | 2 +- types/txn_decode.go | 10 +++++++--- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/rlp/decoder.go b/rlp/decoder.go index 56a106856..d6cad3fae 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -113,6 +113,9 @@ func (d *Decoder) RawElem() ([]byte, Token, error) { return nil, token, err } _, err = nextFull(w, sz) + if err != nil { + return nil, token, err + } default: return nil, token, fmt.Errorf("%w: unknown token", ErrDecode) } @@ -163,6 +166,9 @@ func (d *Decoder) Elem() ([]byte, Token, error) { return nil, token, err } buf, err = nextFull(w, sz) + if err != nil { + return nil, token, err + } default: return nil, token, fmt.Errorf("%w: unknown token", ErrDecode) } diff --git a/rlp/util.go b/rlp/util.go index 6b04dfb2a..0219e1d39 100644 --- a/rlp/util.go +++ b/rlp/util.go @@ -52,7 +52,7 @@ const ( func identifyToken(b byte) Token { switch { - case b >= 0 && b <= 127: + case b <= 127: return TokenDecimal case b >= 128 && b <= 183: return TokenShortBlob diff --git a/types/txn_decode.go b/types/txn_decode.go index f5b29df49..549b0622e 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -43,6 +43,9 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, } // start classification token, err := decoder.PeekToken() + if err != nil { + return err + } // means that this is non-enveloped non-legacy transaction if token == rlp.TokenDecimal { if hasEnvelope { @@ -117,8 +120,6 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, if err != nil { return fmt.Errorf("wrapped blob tx body: %w", err) //nolint } - } else { - // otherwise its not wrapped with blobs, so we do nothing special } } } @@ -269,6 +270,9 @@ func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.D }) return err }) + if err != nil { + return err + } } if slot.Type == BlobTxType { @@ -380,7 +384,7 @@ func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.D return fmt.Errorf("computing signHash (hashing len Prefix): %s", err) //nolint } } else { - beLen := common.BitLenToByteLen(bits.Len(uint(sigHashLen))) + beLen := common.BitLenToByteLen(bits.Len(sigHashLen)) binary.BigEndian.PutUint64(ctx.buf[1:], uint64(sigHashLen)) ctx.buf[8-beLen] = byte(beLen) + 247 if _, err := k2.Write(ctx.buf[8-beLen : 9]); err != nil { From ecf5e217bfb27a0f4902e1ecc5009a8833a09b8e Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 08:02:54 -0500 Subject: [PATCH 32/34] a --- types/txn_decode.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/types/txn_decode.go b/types/txn_decode.go index 549b0622e..2461593f5 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -56,7 +56,6 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, var ( dec *rlp.Decoder ) - var ( parent *rlp.Decoder // the parent should contain in its underlying buffer the rlp used for txn hash creation sans txn type bodyDecoder *rlp.Decoder // the bodyDecoder should be an rlp decoder primed at the top of the list body for txn @@ -381,14 +380,14 @@ func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.D if sigHashLen < 56 { ctx.buf[0] = byte(sigHashLen) + 192 if _, err := k2.Write(ctx.buf[:1]); err != nil { - return fmt.Errorf("computing signHash (hashing len Prefix): %s", err) //nolint + return fmt.Errorf("computing signHash (hashing len Prefix1): %s", err) //nolint } } else { beLen := common.BitLenToByteLen(bits.Len(sigHashLen)) binary.BigEndian.PutUint64(ctx.buf[1:], uint64(sigHashLen)) ctx.buf[8-beLen] = byte(beLen) + 247 if _, err := k2.Write(ctx.buf[8-beLen : 9]); err != nil { - return fmt.Errorf("computing signHash (hashing len Prefix): %s", err) //nolint + return fmt.Errorf("computing signHash (hashing len Prefix2): %s", err) //nolint } } if _, err = k2.Write(dec.Underlying()[sigHashPos:sigHashEnd]); err != nil { From 0fbb99e39b95eae0de92c76f1ea47e3a3681dbc7 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 08:06:10 -0500 Subject: [PATCH 33/34] a --- types/txn_decode.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/types/txn_decode.go b/types/txn_decode.go index 2461593f5..3d44a168e 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -167,10 +167,10 @@ func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.D if !legacy { typeByte := []byte{slot.Type} if _, err := k1.Write(typeByte); err != nil { - return err + return fmt.Errorf("compute idHash: %w", err) } if _, err := k2.Write(typeByte); err != nil { - return err + return fmt.Errorf("compute Hash: %w", err) } if _, err := k1.Write(parent.Underlying()); err != nil { return fmt.Errorf("compute idHash: %w", err) From 860462f7688b88e6e7b5d01cec61068d4e8400e2 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 7 Sep 2023 08:18:48 -0500 Subject: [PATCH 34/34] test --- rlp/decoder.go | 16 +++++++--------- types/txn_decode.go | 17 ++++++++++------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/rlp/decoder.go b/rlp/decoder.go index d6cad3fae..4f41ecd7f 100644 --- a/rlp/decoder.go +++ b/rlp/decoder.go @@ -1,6 +1,7 @@ package rlp import ( + "errors" "fmt" "io" ) @@ -108,14 +109,11 @@ func (d *Decoder) RawElem() ([]byte, Token, error) { _, err = nextFull(w, sz) case TokenLongBlob: lenSz := int(token.Diff(prefix)) - sz, err := nextBeInt(w, lenSz) + sz, err = nextBeInt(w, lenSz) if err != nil { return nil, token, err } _, err = nextFull(w, sz) - if err != nil { - return nil, token, err - } default: return nil, token, fmt.Errorf("%w: unknown token", ErrDecode) } @@ -161,20 +159,17 @@ func (d *Decoder) Elem() ([]byte, Token, error) { buf, err = nextFull(w, sz) case TokenLongBlob: lenSz := int(token.Diff(prefix)) - sz, err := nextBeInt(w, lenSz) + sz, err = nextBeInt(w, lenSz) if err != nil { return nil, token, err } buf, err = nextFull(w, sz) - if err != nil { - return nil, token, err - } default: return nil, token, fmt.Errorf("%w: unknown token", ErrDecode) } //log.Printf("%x %s\n", buf, token) if err != nil { - return nil, token, err + return nil, token, fmt.Errorf("read data: %w", err) } return buf, token, nil } @@ -210,6 +205,9 @@ func (d *Decoder) ForList(fn func(*Decoder) error) error { return nil } err := fn(dec) + if errors.Is(err, io.EOF) { + return nil + } if err != nil { return err } diff --git a/types/txn_decode.go b/types/txn_decode.go index 3d44a168e..7f1c6c2da 100644 --- a/types/txn_decode.go +++ b/types/txn_decode.go @@ -96,7 +96,7 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, // from here to the end of the element, if this is not a blob tx type with blobs, is the parent parent, _, err = decoder.RawElemDec() if err != nil { - return fmt.Errorf("extract txn body: %w", err) //nolint + return fmt.Errorf("extract txn body parent: %w", err) //nolint } // now enter the list, since that is what we are in front of now. bodyDecoder, _, err = parent.ElemDec() @@ -112,7 +112,7 @@ func (ctx *TxParseContext) decodeTransaction(decoder *rlp.Decoder, slot *TxSlot, // we can extract the raw elem out of this in order to get the parent decoder parent, _, err = bodyDecoder.RawElemDec() if err != nil { - return fmt.Errorf("wrapped blob tx body: %w", err) //nolint + return fmt.Errorf("wrapped blob tx body parent: %w", err) //nolint } // and then the body is actually just the parent decoder read once, since we are sitting at the top of the [txbody...] header bodyDecoder, _, err = parent.ElemDec() @@ -179,7 +179,7 @@ func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.D if ctx.validateRlp != nil { if err := ctx.validateRlp(slot.Rlp); err != nil { - return err + return fmt.Errorf("validate rlp: %w", err) } } @@ -257,20 +257,23 @@ func (ctx *TxParseContext) decodeTransactionBody(dec *rlp.Decoder, parent *rlp.D slot.AlAddrCount++ err := rlp.ReadElem(ld, rlp.Skip, nil) if err != nil { - return err + return fmt.Errorf("elem: %w", err) } err = ld.ForList(func(sk *rlp.Decoder) error { slot.AlStorCount++ err := rlp.ReadElem(sk, rlp.Skip, nil) if err != nil { - return err + return fmt.Errorf("elem: %w", err) } return nil }) - return err + if err != nil { + return fmt.Errorf("iterate: %w", err) + } + return nil }) if err != nil { - return err + return fmt.Errorf("iterate: %w", err) } }