diff --git a/.travis.yml b/.travis.yml
index 0ff1b42adb..432fdf12cf 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -17,8 +17,8 @@ env:
   - TOX_POSARGS="-e py35-transactions"
   - TOX_POSARGS="-e py35-vm-fast"
   - TOX_POSARGS="-e py35-vm-limits"
-  - TOX_POSARGS="-e py35-leveldb"
   - TOX_POSARGS="-e py35-p2p"
+  - TOX_POSARGS="-e py35-database"
   #- TOX_POSARGS="-e py35-vm-performance"
   - TOX_POSARGS="-e flake8"
 cache:
diff --git a/evm/chain.py b/evm/chain.py
index 07c2a2827d..3058d1b736 100644
--- a/evm/chain.py
+++ b/evm/chain.py
@@ -44,7 +44,7 @@
 )
 from evm.utils.rlp import diff_rlp_object
 
-from evm.state import State
+from evm.db.state import State
 
 
 class Chain(object):
diff --git a/evm/db/backends/base.py b/evm/db/backends/base.py
index 205b0f1a13..6c67069937 100644
--- a/evm/db/backends/base.py
+++ b/evm/db/backends/base.py
@@ -19,19 +19,6 @@ def delete(self, key):
             "The `delete` method must be implemented by subclasses of BaseDB"
         )
 
-    #
-    # Snapshot API
-    #
-    def snapshot(self):
-        raise NotImplementedError(
-            "The `snapshot` method must be implemented by subclasses of BaseDB"
-        )
-
-    def revert(self, snapshot):
-        raise NotImplementedError(
-            "The `revert` method must be implemented by subclasses of BaseDB"
-        )
-
     #
     # Dictionary API
     #
diff --git a/evm/db/backends/level.py b/evm/db/backends/level.py
index 13c75aa89d..b5740942ab 100644
--- a/evm/db/backends/level.py
+++ b/evm/db/backends/level.py
@@ -1,4 +1,5 @@
 import shutil
+
 from .base import (
     BaseDB,
 )
@@ -32,29 +33,6 @@ def exists(self, key):
     def delete(self, key):
         self.db.Delete(key)
 
-    #
-    # Snapshot API
-    #
-    def snapshot(self):
-        return Snapshot(self.db.CreateSnapshot())
-
-    def revert(self, snapshot):
-        for item in self.db.RangeIter(include_value=False):
-            self.db.Delete(item)
-        for key, val in snapshot.items():
-            self.db.Put(key, val)
-
     # Clears the leveldb
     def __del__(self):
         shutil.rmtree(self.db_path, ignore_errors=True)
-
-
-class Snapshot(object):
-    def __init__(self, snapshot):
-        self.db = snapshot
-
-    def get(self, key):
-        return self.db.Get(key)
-
-    def items(self):
-        return self.db.RangeIter(include_value=True, reverse=True)
diff --git a/evm/db/backends/memory.py b/evm/db/backends/memory.py
index 6dcbea86ab..945543cea9 100644
--- a/evm/db/backends/memory.py
+++ b/evm/db/backends/memory.py
@@ -1,4 +1,3 @@
-import copy
 from .base import (
     BaseDB,
 )
@@ -21,12 +20,3 @@ def exists(self, key):
 
     def delete(self, key):
         del self.kv_store[key]
-
-    #
-    # Snapshot API
-    #
-    def snapshot(self):
-        return copy.copy(self.kv_store)
-
-    def revert(self, snapshot):
-        self.kv_store = snapshot
diff --git a/evm/db/hash_trie.py b/evm/db/hash_trie.py
new file mode 100644
index 0000000000..606baa992f
--- /dev/null
+++ b/evm/db/hash_trie.py
@@ -0,0 +1,30 @@
+from evm.utils.keccak import (
+    keccak,
+)
+
+
+class HashTrie(object):
+    _trie = None
+
+    def __init__(self, trie):
+        self._trie = trie
+
+    def __setitem__(self, key, value):
+        self._trie[keccak(key)] = value
+
+    def __getitem__(self, key):
+        return self._trie[keccak(key)]
+
+    def __delitem__(self, key):
+        del self._trie[keccak(key)]
+
+    def __contains__(self, key):
+        return keccak(key) in self._trie
+
+    @property
+    def root_hash(self):
+        return self._trie.root_hash
+
+    @root_hash.setter
+    def root_hash(self, value):
+        self._trie.root_hash = value
diff --git a/evm/db/journal.py b/evm/db/journal.py
new file mode 100644
index 0000000000..e0381525fa
--- /dev/null
+++ b/evm/db/journal.py
@@ -0,0 +1,225 @@
+import uuid
+
+from cytoolz import (
+    merge,
+)
+
+from evm.db.backends.base import BaseDB
+from evm.exceptions import ValidationError
+
+
+class Journal(object):
+    """
+    A Journal is an ordered list of checkpoints. A checkpoint is a dictionary
+    of database keys and values. The values are the "original" value of that
+    key at the time the checkpoint was created.
+
+    Checkpoints are referenced by a random uuid4.
+    """
+    checkpoints = None
+
+    def __init__(self):
+        # contains an array of `uuid4` instances
+        self.checkpoints = []
+        # contains a mapping from all of the `uuid4` in the `checkpoints` array
+        # to a dictionary of key:value pairs where the `value` is the original
+        # value for the given key at the moment this checkpoint was created.
+        self.journal_data = {}
+
+    @property
+    def latest_id(self):
+        """
+        Returns the checkpoint_id of the latest checkpoint
+        """
+        return self.checkpoints[-1]
+
+    @property
+    def latest(self):
+        """
+        Returns the dictionary of db keys and values for the latest checkpoint.
+        """
+        return self.journal_data[self.latest_id]
+
+    @latest.setter
+    def latest(self, value):
+        """
+        Setter for updating the *latest* checkpoint.
+        """
+        self.journal_data[self.latest_id] = value
+
+    def add(self, key, value):
+        """
+        Adds the given key and value to the latest checkpoint.
+        """
+        if not self.checkpoints:
+            # If no checkpoints exist we don't need to track history.
+            return
+        elif key in self.latest:
+            # If the key is already in the latest checkpoint we should not
+            # overwrite it.
+            return
+        self.latest[key] = value
+
+    def create_checkpoint(self):
+        """
+        Creates a new checkpoint. Checkpoints are referenced by a random uuid4
+        to prevent collisions between multiple checkpoints.
+        """
+        checkpoint_id = uuid.uuid4()
+        self.checkpoints.append(checkpoint_id)
+        self.journal_data[checkpoint_id] = {}
+        return checkpoint_id
+
+    def pop_checkpoint(self, checkpoint_id):
+        """
+        Returns all changes from the given checkpoint. This includes all of
+        the changes from any subsequent checkpoints, giving precedence to
+        earlier checkpoints.
+        """
+        idx = self.checkpoints.index(checkpoint_id)
+
+        # update the checkpoint list
+        checkpoint_ids = self.checkpoints[idx:]
+        self.checkpoints = self.checkpoints[:idx]
+
+        # we pull all of the checkpoints *after* the checkpoint we are
+        # reverting to and collapse them to a single set of keys that need to
+        # be reverted (giving precedence to earlier checkpoints).
+        revert_data = merge(*(
+            self.journal_data.pop(c_id)
+            for c_id
+            in reversed(checkpoint_ids)
+        ))
+
+        return dict(revert_data.items())
+
+    def commit_checkpoint(self, checkpoint_id):
+        """
+        Collapses all changes for the given checkpoint into the previous
+        checkpoint if it exists.
+        """
+        changes_to_merge = self.pop_checkpoint(checkpoint_id)
+        if self.checkpoints:
+            # we only have to merge the changes into the latest checkpoint if
+            # there is one.
+            self.latest = merge(
+                changes_to_merge,
+                self.latest,
+            )
+
+    def __contains__(self, value):
+        return value in self.journal_data
+
+
+class JournalDB(BaseDB):
+    """
+    A wrapper around the basic DB objects that keeps a journal of all changes.
+    Each time a snapshot is taken, the underlying journal creates a new
+    checkpoint. The journal then keeps track of the original value for any
+    keys changed. Reverting to a checkpoint involves merging the original key
+    data from any subsequent checkpoints into the given checkpoint giving
+    precedence to earlier checkpoints. Then the keys from this merged data set
+    are reset to their original values.
+
+    The added memory footprint for a JournalDB is one key/value stored per
+    database key which is changed. Subsequent changes to the same key within
+    the same checkpoint will not increase the journal size since we only need
+    to track the original value for any given key within any given checkpoint.
+    """
+    wrapped_db = None
+    journal = None
+
+    def __init__(self, wrapped_db):
+        self.wrapped_db = wrapped_db
+        self.journal = Journal()
+
+    def get(self, key):
+        return self.wrapped_db.get(key)
+
+    def set(self, key, value):
+        """
+        - replacing an existing value
+        - setting a value that does not exist
+        """
+        try:
+            current_value = self.wrapped_db.get(key)
+        except KeyError:
+            current_value = None
+
+        if current_value != value:
+            # only journal `set` operations that change the value.
+            self.journal.add(key, current_value)
+
+        return self.wrapped_db.set(key, value)
+
+    def exists(self, key):
+        return self.wrapped_db.exists(key)
+
+    def delete(self, key):
+        try:
+            current_value = self.wrapped_db.get(key)
+        except KeyError:
+            # no state change so skip journaling
+            pass
+        else:
+            self.journal.add(key, current_value)
+
+        return self.wrapped_db.delete(key)
+
+    #
+    # Snapshot API
+    #
+    def _validate_checkpoint(self, checkpoint):
+        """
+        Checks to be sure the checkpoint is known by the journal
+        """
+        if checkpoint not in self.journal:
+            raise ValidationError("Checkpoint not found in journal: {0}".format(
+                str(checkpoint)
+            ))
+
+    def snapshot(self):
+        """
+        Takes a snapshot of the database by creating a checkpoint.
+        """
+        return self.journal.create_checkpoint()
+
+    def revert(self, checkpoint):
+        """
+        Reverts the database back to the checkpoint.
+        """
+        self._validate_checkpoint(checkpoint)
+
+        for key, value in self.journal.pop_checkpoint(checkpoint).items():
+            if value is None:
+                self.wrapped_db.delete(key)
+            else:
+                self.wrapped_db.set(key, value)
+
+    def commit(self, checkpoint):
+        """
+        Commits a given checkpoint.
+        """
+        self._validate_checkpoint(checkpoint)
+        self.journal.commit_checkpoint(checkpoint)
+
+    def clear(self):
+        """
+        Clear the entire journal.
+        """
+        self.journal = Journal()
+
+    #
+    # Dictionary API
+    #
+    def __getitem__(self, key):
+        return self.get(key)
+
+    def __setitem__(self, key, value):
+        return self.set(key, value)
+
+    def __delitem__(self, key):
+        return self.delete(key)
+
+    def __contains__(self, key):
+        return self.exists(key)
diff --git a/evm/state.py b/evm/db/state.py
similarity index 85%
rename from evm/state.py
rename to evm/db/state.py
index 463a09e9c5..5a9c04b559 100644
--- a/evm/state.py
+++ b/evm/db/state.py
@@ -29,36 +29,7 @@
     pad32,
 )
 
-
-class HashTrie(object):
-    _trie = None
-
-    logger = logging.getLogger('evm.state.HashTrie')
-
-    def __init__(self, trie):
-        self._trie = trie
-
-    def __setitem__(self, key, value):
-        self._trie[keccak(key)] = value
-
-    def __getitem__(self, key):
-        return self._trie[keccak(key)]
-
-    def __delitem__(self, key):
-        del self._trie[keccak(key)]
-
-    def __contains__(self, key):
-        return keccak(key) in self._trie
-
-    @property
-    def root_hash(self):
-        return self._trie.root_hash
-
-    def snapshot(self):
-        return self._trie.snapshot()
-
-    def revert(self, snapshot):
-        return self._trie.revert(snapshot)
+from .hash_trie import HashTrie
 
 
 class State(object):
@@ -81,6 +52,10 @@ def __init__(self, db, root_hash=BLANK_ROOT_HASH):
     def root_hash(self):
         return self._trie.root_hash
 
+    @root_hash.setter
+    def root_hash(self, value):
+        self._trie.root_hash = value
+
     def set_storage(self, address, slot, value):
         validate_uint256(value, title="Storage Value")
         validate_uint256(slot, title="Storage Slot")
@@ -198,15 +173,6 @@ def increment_nonce(self, address):
         current_nonce = self.get_nonce(address)
         self.set_nonce(address, current_nonce + 1)
 
-    #
-    # Internal
-    #
-    def snapshot(self):
-        return self._trie.snapshot()
-
-    def revert(self, snapshot):
-        return self._trie.revert(snapshot)
-
     #
     # Internal
     #
diff --git a/evm/vm/base.py b/evm/vm/base.py
index ca669c0c79..bcc1acbc23 100644
--- a/evm/vm/base.py
+++ b/evm/vm/base.py
@@ -8,12 +8,15 @@
     NEPHEW_REWARD,
     UNCLE_DEPTH_PENALTY_FACTOR,
 )
-from evm.logic.invalid import (
-    InvalidOpcode,
+from evm.db.journal import (
+    JournalDB,
 )
-from evm.state import (
+from evm.db.state import (
     State,
 )
+from evm.logic.invalid import (
+    InvalidOpcode,
+)
 
 from evm.utils.blocks import (
     get_block_header_by_hash,
@@ -37,6 +40,7 @@ def __init__(self, header, db):
             raise ValueError("VM classes must have a `db`")
         self.db = db
+        self.journal_db = JournalDB(self.db)
 
         block_class = self.get_block_class()
         self.block = block_class.from_header(header=header, db=db)
 
@@ -59,7 +63,7 @@ def configure(cls,
 
     @contextmanager
     def state_db(self, read_only=False):
-        state = State(db=self.db, root_hash=self.block.header.state_root)
+        state = State(db=self.journal_db, root_hash=self.block.header.state_root)
         yield state
         if read_only:
             # TODO: This is a bit of a hack; ideally we should raise an error whenever the
@@ -84,6 +88,7 @@ def apply_transaction(self, transaction):
         """
         Apply the transaction to the vm in the current block.
         """
         computation = self.execute_transaction(transaction)
+        self.clear_journal()
         self.block.add_transaction(transaction, computation)
         return computation
@@ -255,19 +260,38 @@ def snapshot(self):
         """
         Perform a full snapshot of the current state of the VM.
 
-        TODO: This needs to do more than just snapshot the state_db but this is a start.
+        Snapshots are a combination of the state_root at the time of the
+        snapshot and the checkpoint_id returned from the journaled DB.
         """
-        with self.state_db(read_only=True) as state_db:
-            return state_db.snapshot()
+        return (self.block.header.state_root, self.journal_db.snapshot())
 
     def revert(self, snapshot):
         """
-        Revert the VM to the state
-
-        TODO: This needs to do more than just snapshot the state_db but this is a start.
+        Revert the VM to the state at the snapshot
         """
+        state_root, checkpoint_id = snapshot
+
         with self.state_db() as state_db:
-            return state_db.revert(snapshot)
+            # first revert the database state root.
+            state_db.root_hash = state_root
+        # now roll the underlying database back
+        self.journal_db.revert(checkpoint_id)
+
+    def commit(self, snapshot):
+        """
+        Commits the journal to the point where the snapshot was taken. This
+        will destroy any journal checkpoints *after* the snapshot checkpoint.
+        """
+        _, checkpoint_id = snapshot
+        self.journal_db.commit(checkpoint_id)
+
+    def clear_journal(self):
+        """
+        Clear the journal. This should be called at any point of VM execution
+        where the statedb is being committed, such as after a transaction has
+        been applied to a block.
+        """
+        self.journal_db.clear()
 
     #
     # Opcode API
diff --git a/evm/vm/flavors/frontier/__init__.py b/evm/vm/flavors/frontier/__init__.py
index 3a55ff6d84..5d9831c916 100644
--- a/evm/vm/flavors/frontier/__init__.py
+++ b/evm/vm/flavors/frontier/__init__.py
@@ -221,6 +221,9 @@ def _apply_frontier_message(vm, message):
 
     if computation.error:
         vm.revert(snapshot)
+    else:
+        vm.commit(snapshot)
+
     return computation
 
 
diff --git a/evm/vm/flavors/homestead/__init__.py b/evm/vm/flavors/homestead/__init__.py
index f73dc69807..4479659b09 100644
--- a/evm/vm/flavors/homestead/__init__.py
+++ b/evm/vm/flavors/homestead/__init__.py
@@ -33,6 +33,7 @@ def _apply_homestead_create_message(vm, message):
     computation = vm.apply_message(message)
 
     if computation.error:
+        vm.commit(snapshot)
         return computation
     else:
         contract_code = computation.output
@@ -56,8 +57,12 @@ def _apply_homestead_create_message(vm, message):
                         encode_hex(message.storage_address),
                         contract_code,
                     )
+
                 with vm.state_db() as state_db:
                     state_db.set_code(message.storage_address, contract_code)
+                vm.commit(snapshot)
+        else:
+            vm.commit(snapshot)
 
         return computation
 
diff --git a/setup.py b/setup.py
index c77e969afa..32cec05c1a 100644
--- a/setup.py
+++ b/setup.py
@@ -26,14 +26,14 @@
     include_package_data=True,
     py_modules=['evm'],
     install_requires=[
+        "cryptography>=2.0.3",
         "cytoolz==0.8.2",
         "ethereum-bloom>=0.4.0",
+        "ethereum-keys==0.1.0a6",
         "ethereum-utils>=0.2.0",
         "pyethash>=0.1.27",
         "rlp==0.4.7",
-        "trie==0.2.4",
-        "ethereum-keys==0.1.0a6",
-        "cryptography>=2.0.3",
+        "trie>=0.3.0",
     ],
     extra_require={
         'leveldb': [
diff --git a/tests/database/test_journal_db.py b/tests/database/test_journal_db.py
new file mode 100644
index 0000000000..56f6e84733
--- /dev/null
+++ b/tests/database/test_journal_db.py
@@ -0,0 +1,92 @@
+import pytest
+from evm.db.backends.memory import MemoryDB
+from evm.db.journal import JournalDB
+
+
+@pytest.fixture
+def journal_db():
+    return JournalDB(MemoryDB())
+
+
+def test_set_and_get(journal_db):
+    journal_db.set(b'1', b'test')
+
+    assert journal_db.get(b'1') == b'test'
+
+
+def test_get_non_existent_value(journal_db):
+    with pytest.raises(KeyError):
+        journal_db.get(b'does-not-exist')
+
+
+def test_delete_non_existent_value(journal_db):
+    with pytest.raises(KeyError):
+        journal_db.delete(b'does-not-exist')
+
+
+def test_snapshot_and_revert_with_set(journal_db):
+    journal_db.set(b'1', b'test-a')
+
+    assert journal_db.get(b'1') == b'test-a'
+
+    snapshot = journal_db.snapshot()
+
+    journal_db.set(b'1', b'test-b')
+
+    assert journal_db.get(b'1') == b'test-b'
+
+    journal_db.revert(snapshot)
+
+    assert journal_db.get(b'1') == b'test-a'
+
+
+def test_snapshot_and_revert_with_delete(journal_db):
+    journal_db.set(b'1', b'test-a')
+
+    assert journal_db.exists(b'1') is True
+    assert journal_db.get(b'1') == b'test-a'
+
+    snapshot = journal_db.snapshot()
+
+    journal_db.delete(b'1')
+
+    assert journal_db.exists(b'1') is False
+
+    journal_db.revert(snapshot)
+
+    assert journal_db.exists(b'1') is True
+    assert journal_db.get(b'1') == b'test-a'
+
+
+def test_revert_clears_reverted_journal_entries(journal_db):
+    journal_db.set(b'1', b'test-a')
+
+    assert journal_db.get(b'1') == b'test-a'
+
+    snapshot_a = journal_db.snapshot()
+
+    journal_db.set(b'1', b'test-b')
+    journal_db.delete(b'1')
+    journal_db.set(b'1', b'test-c')
+
+    assert journal_db.get(b'1') == b'test-c'
+
+    snapshot_b = journal_db.snapshot()
+
+    journal_db.set(b'1', b'test-d')
+    journal_db.delete(b'1')
+    journal_db.set(b'1', b'test-e')
+
+    assert journal_db.get(b'1') == b'test-e'
+
+    journal_db.revert(snapshot_b)
+
+    assert journal_db.get(b'1') == b'test-c'
+
+    journal_db.delete(b'1')
+
+    assert journal_db.exists(b'1') is False
+
+    journal_db.revert(snapshot_a)
+
+    assert journal_db.get(b'1') == b'test-a'
diff --git a/tests/level-db/test_leveldb_db_backend.py b/tests/database/test_leveldb_db_backend.py
similarity index 82%
rename from tests/level-db/test_leveldb_db_backend.py
rename to tests/database/test_leveldb_db_backend.py
index 430dafd1f3..f7dfb20fb1 100644
--- a/tests/level-db/test_leveldb_db_backend.py
+++ b/tests/database/test_leveldb_db_backend.py
@@ -57,14 +57,3 @@ def test_delete(level_db, memory_db):
     level_db.delete(b'1')
     memory_db.delete(b'1')
     assert level_db.exists(b'1') == memory_db.exists(b'1')
-
-
-def test_snapshot_and_revert(level_db):
-    snapshot = level_db.snapshot()
-    level_db.set(b'1', b'1')
-    assert level_db.get(b'1')
-    with pytest.raises(KeyError):
-        snapshot.get(b'1')
-    level_db.revert(snapshot)
-    with pytest.raises(KeyError):
-        level_db.get(b'1')
diff --git a/tests/json-fixtures/test_blockchain.py b/tests/json-fixtures/test_blockchain.py
index 351bfbe80a..9f41f44bf2 100644
--- a/tests/json-fixtures/test_blockchain.py
+++ b/tests/json-fixtures/test_blockchain.py
@@ -114,8 +114,8 @@ def test_blockchain_fixtures(fixture_name, fixture):
     # TODO: find out if this is supposed to pass?
     # if 'genesisRLP' in fixture:
     #     assert rlp.encode(genesis_header) == fixture['genesisRLP']
-    db = get_db_backend()
+
     chain = MainnetChain
     # TODO: It would be great if we can figure out an API for re-configuring
     # start block numbers that was more elegant.
diff --git a/tox.ini b/tox.ini
index 7ddd4f156f..814d43567f 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,6 @@
 [tox]
 envlist=
-    py{27,34,35}-{core,p2p,leveldb,state-all,state-fast,state-slow,blockchain-fast,blockchain-slow,transactions,vm-all,vm-fast,vm-limits,vm-performance}
+    py{27,34,35}-{core,p2p,database,state-all,state-fast,state-slow,blockchain-fast,blockchain-slow,transactions,vm-all,vm-fast,vm-limits,vm-performance}
     flake8
 
[flake8]
@@ -17,7 +17,7 @@ commands=
    blockchain-slow: py.test {posargs:tests/json-fixtures/test_blockchain.py -m blockchain_slow}
    core: py.test {posargs:tests/core}
    p2p: py.test {posargs:evm/p2p}
-    leveldb: py.test {posargs:tests/level-db}
+    database: py.test {posargs:tests/database}
    state-all: py.test {posargs:tests/json-fixtures/test_state.py}
    state-fast: py.test {posargs:tests/json-fixtures/test_state.py -m "not state_slow"}
    state-slow: py.test {posargs:tests/json-fixtures/test_state.py -m state_slow}
@@ -28,8 +28,8 @@ commands=
    vm-performance: py.test {posargs:tests/json-fixtures/test_virtual_machine.py -m vm_performance}
 deps =
    -r{toxinidir}/requirements-dev.txt
-    coincurve: coincurve
-    leveldb: leveldb
+    coincurve
+    database: leveldb
 basepython =
    py27: python2.7
    py34: python3.4