diff --git a/blockchain/blockchain.go b/blockchain/blockchain.go index 0aa2e63de5..4dd346f5d1 100644 --- a/blockchain/blockchain.go +++ b/blockchain/blockchain.go @@ -144,8 +144,9 @@ func (b *Blockchain) Head() (*core.Block, error) { return nil, err } - txn := b.database.NewIndexedBatch() - return b.transactionLayout.BlockByNumber(txn, curHeight) + snapshot := b.database.NewSnapshot() + defer snapshot.Close() + return b.transactionLayout.BlockByNumber(snapshot, curHeight) } func (b *Blockchain) HeadsHeader() (*core.Header, error) { @@ -169,8 +170,9 @@ func headsHeader(txn db.KeyValueReader) (*core.Header, error) { func (b *Blockchain) BlockByNumber(number uint64) (*core.Block, error) { b.listener.OnRead("BlockByNumber") - txn := b.database.NewIndexedBatch() - return b.transactionLayout.BlockByNumber(txn, number) + snapshot := b.database.NewSnapshot() + defer snapshot.Close() + return b.transactionLayout.BlockByNumber(snapshot, number) } func (b *Blockchain) BlockHeaderByNumber(number uint64) (*core.Header, error) { @@ -190,8 +192,9 @@ func (b *Blockchain) BlockByHash(hash *felt.Felt) (*core.Block, error) { return nil, err } - txn := b.database.NewIndexedBatch() - return b.transactionLayout.BlockByNumber(txn, blockNum) + snapshot := b.database.NewSnapshot() + defer snapshot.Close() + return b.transactionLayout.BlockByNumber(snapshot, blockNum) } func (b *Blockchain) BlockHeaderByHash(hash *felt.Felt) (*core.Header, error) { @@ -302,57 +305,62 @@ func (b *Blockchain) Store( stateUpdate *core.StateUpdate, newClasses map[felt.Felt]core.ClassDefinition, ) error { - err := b.database.Update(func(txn db.IndexedBatch) error { - if err := verifyBlock(txn, block); err != nil { - return err - } - - if err := core.NewState(txn).Update(block.Number, stateUpdate, newClasses, false); err != nil { - return err - } - if err := core.WriteBlockHeader(txn, block.Header); err != nil { - return err - } + snapshot := b.database.NewSnapshot() + defer snapshot.Close() + batch := b.database.NewBatch() + txn := db.NewSnapshotBatch(batch, snapshot) + if err := verifyBlock(txn, block); err != nil { + return err + } - err := b.transactionLayout.WriteTransactionsAndReceipts( - txn, - block.Number, - block.Transactions, - block.Receipts, - ) - if err != nil { - return err - } + if err := core.NewState(txn).Update(block.Number, stateUpdate, newClasses, false); err != nil { + return err + } + if err := core.WriteBlockHeader(txn, block.Header); err != nil { + return err + } - if err := core.WriteStateUpdateByBlockNum(txn, block.Number, stateUpdate); err != nil { - return err - } + err := b.transactionLayout.WriteTransactionsAndReceipts( + txn, + block.Number, + block.Transactions, + block.Receipts, + ) + if err != nil { + return err + } - if err := core.WriteBlockCommitment(txn, block.Number, blockCommitments); err != nil { - return err - } + if err := core.WriteStateUpdateByBlockNum(txn, block.Number, stateUpdate); err != nil { + return err + } - if err := core.WriteL1HandlerMsgHashes(txn, block.Transactions); err != nil { - return err - } + if err := core.WriteBlockCommitment(txn, block.Number, blockCommitments); err != nil { + return err + } - err = storeCasmHashMetadata( - txn, - block.Number, - block.ProtocolVersion, - stateUpdate, - newClasses, - ) - if err != nil { - return err - } + if err := core.WriteL1HandlerMsgHashes(txn, block.Transactions); err != nil { + return err + } - return core.WriteChainHeight(txn, block.Number) - }) + err = storeCasmHashMetadata( + txn, + block.Number, + block.ProtocolVersion, + stateUpdate, + newClasses, + ) if err != nil { return err } + if err := core.WriteChainHeight(txn, block.Number); err != nil { + return err + } + + if err := txn.Write(); err != nil { + return err + } + return b.runningFilter.Insert( block.EventsBloom, block.Number, @@ -362,7 +370,7 @@ func (b *Blockchain) Store( // storeCasmHashMetadata stores CASM hash metadata for declared and migrated classes. // See [core.ClassCasmHashMetadata] func storeCasmHashMetadata( - txn db.IndexedBatch, + txn db.SnapshotBatch, blockNumber uint64, protocolVersion string, stateUpdate *core.StateUpdate, @@ -385,7 +393,7 @@ func storeCasmHashMetadata( // storeCasmHashMetadataV2 stores metadata for classes declared with casm hash v2 or // migrated from v1. casm hash v2 is after protocol version >= 0.14.1. func storeCasmHashMetadataV2( - txn db.IndexedBatch, + txn db.SnapshotBatch, blockNumber uint64, stateUpdate *core.StateUpdate, ) error { @@ -431,7 +439,7 @@ func storeCasmHashMetadataV2( // storeDeclaredV1Classes stores metadata for classes declared with V1 hash (protocol < 0.14.1). // It computes the V2 hash from the class definition. func storeCasmHashMetadataV1( - txn db.IndexedBatch, + txn db.SnapshotBatch, blockNumber uint64, stateUpdate *core.StateUpdate, newClasses map[felt.Felt]core.ClassDefinition, @@ -527,33 +535,34 @@ func (b *Blockchain) SanityCheckNewHeight(block *core.Block, stateUpdate *core.S type StateCloser = func() error -var noopStateCloser = func() error { return nil } // TODO: remove this once we refactor the state - // HeadState returns a StateReader that provides a stable view to the latest state func (b *Blockchain) HeadState() (core.StateReader, StateCloser, error) { b.listener.OnRead("HeadState") - txn := b.database.NewIndexedBatch() + snapshot := b.database.NewSnapshot() - _, err := core.GetChainHeight(txn) + _, err := core.GetChainHeight(snapshot) if err != nil { + snapshot.Close() return nil, nil, err } - return core.NewState(txn), noopStateCloser, nil + return core.NewStateSnapshotReader(snapshot), snapshot.Close, nil } // StateAtBlockNumber returns a StateReader that provides // a stable view to the state at the given block number func (b *Blockchain) StateAtBlockNumber(blockNumber uint64) (core.StateReader, StateCloser, error) { b.listener.OnRead("StateAtBlockNumber") - txn := b.database.NewIndexedBatch() + snapshot := b.database.NewSnapshot() - _, err := core.GetBlockHeaderByNumber(txn, blockNumber) + _, err := core.GetBlockHeaderByNumber(snapshot, blockNumber) if err != nil { + snapshot.Close() return nil, nil, err } - return core.NewDeprecatedStateHistory(core.NewState(txn), blockNumber), noopStateCloser, nil + stateReader := core.NewStateSnapshotReader(snapshot) + return core.NewDeprecatedStateHistory(stateReader, blockNumber), snapshot.Close, nil } // StateAtBlockHash returns a StateReader that provides @@ -562,18 +571,20 @@ func (b *Blockchain) StateAtBlockHash(blockHash *felt.Felt) (core.StateReader, S b.listener.OnRead("StateAtBlockHash") if blockHash.IsZero() { memDB := memory.New() - txn := memDB.NewIndexedBatch() - emptyState := core.NewState(txn) - return emptyState, noopStateCloser, nil + snapshot := memDB.NewSnapshot() + emptyState := core.NewStateSnapshotReader(snapshot) + return emptyState, snapshot.Close, nil } - txn := b.database.NewIndexedBatch() - header, err := core.GetBlockHeaderByHash(txn, blockHash) + snapshot := b.database.NewSnapshot() + header, err := core.GetBlockHeaderByHash(snapshot, blockHash) if err != nil { + snapshot.Close() return nil, nil, err } - return core.NewDeprecatedStateHistory(core.NewState(txn), header.Number), noopStateCloser, nil + stateReader := core.NewStateSnapshotReader(snapshot) + return core.NewDeprecatedStateHistory(stateReader, header.Number), snapshot.Close, nil } // EventFilter returns an EventFilter object that is tied to a snapshot of the blockchain @@ -603,23 +614,36 @@ func (b *Blockchain) EventFilter( // RevertHead reverts the head block func (b *Blockchain) RevertHead() error { - return b.database.Update(b.revertHead) + snapshot := b.database.NewSnapshot() + defer snapshot.Close() + batch := b.database.NewBatch() + txn := db.NewSnapshotBatch(batch, snapshot) + if err := b.revertHead(txn); err != nil { + return err + } + if err := txn.Write(); err != nil { + return err + } + return nil } func (b *Blockchain) GetReverseStateDiff() (core.StateDiff, error) { - txn := b.database.NewIndexedBatch() - blockNum, err := core.GetChainHeight(txn) + var reverseStateDiff core.StateDiff + + snapshot := b.database.NewSnapshot() + defer snapshot.Close() + blockNum, err := core.GetChainHeight(snapshot) if err != nil { return core.StateDiff{}, err } - stateUpdate, err := core.GetStateUpdateByBlockNum(txn, blockNum) + stateUpdate, err := core.GetStateUpdateByBlockNum(snapshot, blockNum) if err != nil { return core.StateDiff{}, err } - state := core.NewState(txn) - reverseStateDiff, err := state.GetReverseStateDiff(blockNum, stateUpdate.StateDiff) + state := core.NewStateSnapshotReader(snapshot) + reverseStateDiff, err = state.GetReverseStateDiff(blockNum, stateUpdate.StateDiff) if err != nil { return core.StateDiff{}, err } @@ -627,7 +651,7 @@ func (b *Blockchain) GetReverseStateDiff() (core.StateDiff, error) { return reverseStateDiff, nil } -func (b *Blockchain) revertHead(txn db.IndexedBatch) error { +func (b *Blockchain) revertHead(txn db.SnapshotBatch) error { blockNumber, err := core.GetChainHeight(txn) if err != nil { return err @@ -698,8 +722,9 @@ func (b *Blockchain) Simulate( sign utils.BlockSignFunc, ) (SimulateResult, error) { // Simulate without commit - txn := b.database.NewIndexedBatch() - defer txn.Close() + snapshot := b.database.NewSnapshot() + defer snapshot.Close() + txn := db.NewSnapshotBatch(nil, snapshot) if err := b.updateStateRoots(txn, block, stateUpdate, newClasses); err != nil { return SimulateResult{}, err @@ -734,44 +759,50 @@ func (b *Blockchain) Finalise( newClasses map[felt.Felt]core.ClassDefinition, sign utils.BlockSignFunc, ) error { - err := b.database.Update(func(txn db.IndexedBatch) error { - if err := b.updateStateRoots(txn, block, stateUpdate, newClasses); err != nil { - return err - } - commitments, err := b.updateBlockHash(block, stateUpdate) - if err != nil { - return err - } - if err := b.signBlock(block, stateUpdate, sign); err != nil { - return err - } - if err := b.storeBlockData(txn, block, stateUpdate, commitments); err != nil { - return err - } + snapshot := b.database.NewSnapshot() + defer snapshot.Close() + batch := b.database.NewBatch() + txn := db.NewSnapshotBatch(batch, snapshot) - err = storeCasmHashMetadata( - txn, - block.Number, - block.ProtocolVersion, - stateUpdate, - newClasses, - ) - if err != nil { - return err - } + if err := b.updateStateRoots(txn, block, stateUpdate, newClasses); err != nil { + return err + } + commitments, err := b.updateBlockHash(block, stateUpdate) + if err != nil { + return err + } + if err := b.signBlock(block, stateUpdate, sign); err != nil { + return err + } + if err := b.storeBlockData(txn, block, stateUpdate, commitments); err != nil { + return err + } - return core.WriteChainHeight(txn, block.Number) - }) + err = storeCasmHashMetadata( + txn, + block.Number, + block.ProtocolVersion, + stateUpdate, + newClasses, + ) if err != nil { return err } + if err := core.WriteChainHeight(txn, block.Number); err != nil { + return err + } + + if err := txn.Write(); err != nil { + return err + } + return b.runningFilter.Insert(block.EventsBloom, block.Number) } // updateStateRoots computes and updates state roots in the block and state update func (b *Blockchain) updateStateRoots( - txn db.IndexedBatch, + txn db.SnapshotBatch, block *core.Block, stateUpdate *core.StateUpdate, newClasses map[felt.Felt]core.ClassDefinition, @@ -840,7 +871,7 @@ func (b *Blockchain) signBlock( // storeBlockData persists all block-related data to the database func (b *Blockchain) storeBlockData( - txn db.IndexedBatch, + txn db.SnapshotBatch, block *core.Block, stateUpdate *core.StateUpdate, commitments *core.BlockCommitments, diff --git a/consensus/db/db.go b/consensus/db/db.go index 547cf716bb..f69bdd7eba 100644 --- a/consensus/db/db.go +++ b/consensus/db/db.go @@ -109,8 +109,6 @@ func (s *tendermintDB[V, H, A]) SetWALEntry(entry wal.Entry[V, H, A]) error { func (s *tendermintDB[V, H, A]) LoadAllEntries() iter.Seq2[wal.Entry[V, H, A], error] { return func(yield func(wal.Entry[V, H, A], error) bool) { err := s.db.View(func(snap db.Snapshot) error { - defer snap.Close() - iter, err := snap.NewIterator(WALEntryBucket.Key(), true) if err != nil { return fmt.Errorf("failed to create iter: %w", err) diff --git a/core/contract.go b/core/contract.go index f0e3130096..098caa39a5 100644 --- a/core/contract.go +++ b/core/contract.go @@ -19,7 +19,7 @@ var ( // NewContractUpdater creates an updater for the contract instance at the given address. // Deploy should be called for contracts that were just deployed to the network. -func NewContractUpdater(addr *felt.Felt, txn db.IndexedBatch) (*ContractUpdater, error) { +func NewContractUpdater(addr *felt.Felt, txn db.SnapshotBatch) (*ContractUpdater, error) { contractDeployed, err := deployed(addr, txn) if err != nil { return nil, err @@ -36,7 +36,7 @@ func NewContractUpdater(addr *felt.Felt, txn db.IndexedBatch) (*ContractUpdater, } // DeployContract sets up the database for a new contract. -func DeployContract(addr, classHash *felt.Felt, txn db.IndexedBatch) (*ContractUpdater, error) { +func DeployContract(addr, classHash *felt.Felt, txn db.SnapshotBatch) (*ContractUpdater, error) { contractDeployed, err := deployed(addr, txn) if err != nil { return nil, err @@ -84,7 +84,7 @@ func ContractAddress( ) } -func deployed(addr *felt.Felt, txn db.IndexedBatch) (bool, error) { +func deployed(addr *felt.Felt, txn db.KeyValueReader) (bool, error) { _, err := ContractClassHash(addr, txn) if errors.Is(err, db.ErrKeyNotFound) { return false, nil @@ -100,7 +100,7 @@ type ContractUpdater struct { // Address that this contract instance is deployed to Address *felt.Felt // txn to access the database - txn db.IndexedBatch + txn db.SnapshotBatch } // Purge eliminates the contract instance, deleting all associated data from storage @@ -131,7 +131,7 @@ func (c *ContractUpdater) UpdateNonce(nonce *felt.Felt) error { } // ContractRoot returns the root of the contract storage. -func ContractRoot(addr *felt.Felt, txn db.IndexedBatch) (felt.Felt, error) { +func ContractRoot(addr *felt.Felt, txn db.SnapshotBatch) (felt.Felt, error) { cStorage, err := storage(addr, txn) if err != nil { return felt.Felt{}, err @@ -177,7 +177,7 @@ func ContractClassHash(addr *felt.Felt, txn db.KeyValueReader) (felt.Felt, error return GetContractClassHash(txn, addr) } -func setClassHash(txn db.IndexedBatch, addr, classHash *felt.Felt) error { +func setClassHash(txn db.KeyValueWriter, addr, classHash *felt.Felt) error { classHashKey := db.ContractClassHashKey(addr) return txn.Put(classHashKey, classHash.Marshal()) } @@ -189,7 +189,7 @@ func (c *ContractUpdater) Replace(classHash *felt.Felt) error { // storage returns the [core.Trie] that represents the // storage of the contract. -func storage(addr *felt.Felt, txn db.IndexedBatch) (*trie.Trie, error) { +func storage(addr *felt.Felt, txn db.SnapshotBatch) (*trie.Trie, error) { addrBytes := addr.Marshal() return trie.NewTriePedersen(txn, db.ContractStorage.Key(addrBytes), ContractStorageTrieHeight) } diff --git a/core/state.go b/core/state.go index 9958ef86c0..25c1a3c5b5 100644 --- a/core/state.go +++ b/core/state.go @@ -20,20 +20,23 @@ import ( const globalTrieHeight = 251 var ( - stateVersion = felt.NewFromBytes[felt.Felt]([]byte(`STARKNET_STATE_V0`)) - leafVersion = felt.NewFromBytes[felt.Felt]([]byte(`CONTRACT_CLASS_LEAF_V0`)) - ErrCheckHeadState = errors.New("check head state") + stateVersion = felt.NewFromBytes[felt.Felt]([]byte(`STARKNET_STATE_V0`)) + leafVersion = felt.NewFromBytes[felt.Felt]([]byte(`CONTRACT_CLASS_LEAF_V0`)) + ErrCheckHeadState = errors.New("check head state") + systemContractsClassHash = felt.NewFromUint64[felt.Felt](0) + systemContracts = map[felt.Felt]struct{}{ + *felt.NewFromUint64[felt.Felt](1): {}, + *felt.NewFromUint64[felt.Felt](2): {}, + } ) -var _ StateHistoryReader = (*State)(nil) - type State struct { - txn db.IndexedBatch + txn db.SnapshotBatch *StateSnapshotReader } func NewState( - txn db.IndexedBatch, + txn db.SnapshotBatch, ) *State { return &State{ txn: txn, @@ -135,15 +138,6 @@ func (s *State) Update( return s.verifyStateUpdateRoot(update.NewRoot) } -var ( - systemContractsClassHash = new(felt.Felt).SetUint64(0) - - systemContracts = map[felt.Felt]struct{}{ - felt.FromUint64[felt.Felt](1): {}, - felt.FromUint64[felt.Felt](2): {}, - } -) - func (s *State) updateContracts(stateTrie *trie.Trie, blockNumber uint64, diff *StateDiff, logChanges bool) error { // replace contract instances for addr, classHash := range diff.ReplacedClasses { @@ -639,16 +633,16 @@ func (s *State) revertMigratedCasmClasses( } // storage returns a [core.Trie] that represents the Starknet global state in the given Txn context. -func contractTrie(txn db.IndexedBatch) (*trie.Trie, func() error, error) { +func contractTrie(txn db.SnapshotBatch) (*trie.Trie, func() error, error) { return globalTrie(txn, db.StateTrie, trie.NewTriePedersen) } -func classesTrie(txn db.IndexedBatch) (*trie.Trie, func() error, error) { +func classesTrie(txn db.SnapshotBatch) (*trie.Trie, func() error, error) { return globalTrie(txn, db.ClassesTrie, trie.NewTriePoseidon) } func globalTrie( - txn db.IndexedBatch, + txn db.SnapshotBatch, bucket db.Bucket, newTrie trie.NewTrieFunc, ) (*trie.Trie, func() error, error) { @@ -666,8 +660,9 @@ func globalTrie( return nil, nil, err } - rootKey := new(trie.BitArray) + var rootKey *trie.BitArray if len(val) > 0 { + rootKey = new(trie.BitArray) err = rootKey.UnmarshalBinary(val) if err != nil { return nil, nil, err @@ -687,7 +682,8 @@ func globalTrie( resultingRootKey := gTrie.RootKey() // no updates on the trie, short circuit and return - if resultingRootKey.Equal(rootKey) { + if (resultingRootKey == nil && rootKey == nil) || + (resultingRootKey != nil && rootKey != nil && resultingRootKey.Equal(rootKey)) { return nil } diff --git a/core/state/commonstate/state.go b/core/state/commonstate/state.go index 824dd0777d..22963df8fe 100644 --- a/core/state/commonstate/state.go +++ b/core/state/commonstate/state.go @@ -62,7 +62,7 @@ func NewStateFactory( }, nil } -func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.IndexedBatch) (State, error) { +func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.SnapshotBatch) (State, error) { if !sf.UseNewState { deprecatedState := core.NewState(txn) return deprecatedState, nil @@ -77,7 +77,7 @@ func (sf *StateFactory) NewState(stateRoot *felt.Felt, txn db.IndexedBatch) (Sta func (sf *StateFactory) NewStateReader( stateRoot *felt.Felt, - txn db.IndexedBatch, + txn db.SnapshotBatch, blockNumber uint64, ) (StateReader, error) { if !sf.UseNewState { @@ -96,7 +96,9 @@ func (sf *StateFactory) NewStateReader( func (sf *StateFactory) EmptyState() (StateReader, error) { if !sf.UseNewState { memDB := memory.New() - txn := memDB.NewIndexedBatch() + snapshot := memDB.NewSnapshot() + txn := db.NewSnapshotBatch(nil, snapshot) + defer snapshot.Close() emptyState := core.NewState(txn) return emptyState, nil } diff --git a/db/batch.go b/db/batch.go index c358dd398b..bd81b2810a 100644 --- a/db/batch.go +++ b/db/batch.go @@ -43,3 +43,8 @@ type IndexedBatcher interface { NewIndexedBatch() IndexedBatch NewIndexedBatchWithSize(size int) IndexedBatch } + +type SnapshotBatch interface { + Batch + KeyValueReader +} diff --git a/db/memory/db.go b/db/memory/db.go index 4193db05d9..95c7b02fcf 100644 --- a/db/memory/db.go +++ b/db/memory/db.go @@ -180,6 +180,7 @@ func (d *Database) View(fn func(db.Snapshot) error) error { } snap := d.NewSnapshot() + defer snap.Close() return fn(snap) } diff --git a/db/pebble/db.go b/db/pebble/db.go index 342e2cf2b0..e26d5486b9 100644 --- a/db/pebble/db.go +++ b/db/pebble/db.go @@ -76,6 +76,7 @@ func (d *DB) Update(fn func(w db.IndexedBatch) error) error { func (d *DB) View(fn func(r db.Snapshot) error) error { snap := d.NewSnapshot() + defer snap.Close() return fn(snap) } diff --git a/db/pebblev2/db.go b/db/pebblev2/db.go index 8f7cd781be..4e9e0adc28 100644 --- a/db/pebblev2/db.go +++ b/db/pebblev2/db.go @@ -83,6 +83,7 @@ func (d *DB) Update(fn func(w db.IndexedBatch) error) error { func (d *DB) View(fn func(r db.Snapshot) error) error { snap := d.NewSnapshot() + defer snap.Close() return fn(snap) } diff --git a/db/snapshot_batch.go b/db/snapshot_batch.go new file mode 100644 index 0000000000..1165cd56df --- /dev/null +++ b/db/snapshot_batch.go @@ -0,0 +1,142 @@ +package db + +import ( + "bytes" + "slices" +) + +var _ SnapshotBatch = (*snapshotBatch)(nil) + +type snapshotBatch struct { + batch Batch + snapshot Snapshot + writes map[string][]byte + deletes map[string]struct{} + size int + ranges []deleteRange +} + +type deleteRange struct { + start []byte + end []byte +} + +func NewSnapshotBatch(batch Batch, snapshot Snapshot) *snapshotBatch { + return &snapshotBatch{ + batch: batch, + snapshot: snapshot, + writes: make(map[string][]byte), + deletes: make(map[string]struct{}), + } +} + +func (b *snapshotBatch) Put(key, value []byte) error { + keyStr := string(key) + delete(b.deletes, keyStr) + b.writes[keyStr] = slices.Clone(value) + b.size += len(key) + len(value) + return nil +} + +func (b *snapshotBatch) Delete(key []byte) error { + keyStr := string(key) + delete(b.writes, keyStr) + b.deletes[keyStr] = struct{}{} + b.size += len(key) + return nil +} + +func (b *snapshotBatch) DeleteRange(start, end []byte) error { + b.ranges = append(b.ranges, deleteRange{ + start: slices.Clone(start), + end: slices.Clone(end), + }) + return nil +} + +func (b *snapshotBatch) Size() int { + return b.size +} + +func (b *snapshotBatch) Flush() error { + for _, r := range b.ranges { + if err := b.batch.DeleteRange(r.start, r.end); err != nil { + return err + } + } + for k, entry := range b.writes { + key := []byte(k) + if err := b.batch.Put(key, entry); err != nil { + return err + } + } + for k := range b.deletes { + key := []byte(k) + if err := b.batch.Delete(key); err != nil { + return err + } + } + + b.writes = make(map[string][]byte) + b.deletes = make(map[string]struct{}) + b.ranges = b.ranges[:0] + b.size = 0 + return nil +} + +func (b *snapshotBatch) Write() error { + if err := b.Flush(); err != nil { + return err + } + return b.batch.Write() +} + +func (b *snapshotBatch) Reset() { + b.writes = make(map[string][]byte) + b.deletes = make(map[string]struct{}) + b.ranges = b.ranges[:0] + b.size = 0 +} + +func (b *snapshotBatch) Get(key []byte, cb func([]byte) error) error { + if _, ok := b.deletes[string(key)]; ok { + return ErrKeyNotFound + } + if entry, ok := b.writes[string(key)]; ok { + return cb(entry) + } + if inRange(b.ranges, key) { + return ErrKeyNotFound + } + return b.snapshot.Get(key, cb) +} + +func (b *snapshotBatch) Has(key []byte) (bool, error) { + if entry, ok := b.writes[string(key)]; ok { + return entry != nil, nil + } + if _, ok := b.deletes[string(key)]; ok { + return false, nil + } + if inRange(b.ranges, key) { + return false, nil + } + return b.snapshot.Has(key) +} + +func inRange(ranges []deleteRange, key []byte) bool { + for _, r := range ranges { + if bytes.Compare(key, r.start) >= 0 && bytes.Compare(key, r.end) < 0 { + return true + } + } + return false +} + +func (b *snapshotBatch) NewIterator(prefix []byte, withUpperBound bool) (Iterator, error) { + return b.snapshot.NewIterator(prefix, withUpperBound) +} + +func (b *snapshotBatch) Close() error { + return b.batch.Close() +}