diff --git a/README.md b/README.md index 22077fcf..8837d0af 100644 --- a/README.md +++ b/README.md @@ -178,6 +178,7 @@ For example: ```sh nile send ownable0 transfer_ownership 0x07db6...60e794 +Calling transfer_ownership on ownable0 with params: ['0x07db6...60e794'] Invoke transaction was sent. Contract address: 0x07db6b52c8ab888183277bc6411c400136fe566c0eebfb96fffa559b2e60e794 Transaction hash: 0x1c @@ -218,13 +219,13 @@ Please note: ### `run` -Execute a script in the context of Nile. The script must implement a `run(nre)` function to receive a `NileRuntimeEnvironment` object exposing Nile's scripting API. +Execute a script in the context of Nile. The script must implement an asynchronous `run(nre)` function to receive a `NileRuntimeEnvironment` object exposing Nile's scripting API. ```python # path/to/script.py -def run(nre): - address, abi = nre.deploy("contract", alias="my_contract") +async def run(nre): + address, abi = await nre.deploy("contract", alias="my_contract") print(abi, address) ``` @@ -374,18 +375,18 @@ Retrieves a list of ready-to-use accounts which allows for easy scripting integr Next, write a script and call `get-accounts` to retrieve and use the deployed accounts. ```python -def run(nre): +async def run(nre): # fetch the list of deployed accounts - accounts = nre.get_accounts() + accounts = await nre.get_accounts() # then - accounts[0].send(...) + await accounts[0].send(...) # or alice, bob, *_ = accounts - alice.send(...) - bob.send(...) + await alice.send(...) + await bob.send(...) ``` > Please note that the list of accounts include only those that exist in the local `.accounts.json` file. diff --git a/setup.cfg b/setup.cfg index 531731df..c88f7dac 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,9 @@ python_requires = >=3.7 install_requires = click>=8.0,<9.0 + asyncclick>=8.0,<9.0 + anyio + cairo-lang importlib-metadata>=4.0 python-dotenv>=0.19.2 diff --git a/src/nile/cli.py b/src/nile/cli.py index 0853b3fb..c55077e6 100755 --- a/src/nile/cli.py +++ b/src/nile/cli.py @@ -2,7 +2,7 @@ """Nile CLI entry point.""" import logging -import click +import asyncclick as click from nile.core.account import Account from nile.core.call_or_invoke import call_or_invoke as call_or_invoke_command @@ -72,9 +72,9 @@ def install(): @cli.command() @click.argument("path", nargs=1) @network_option -def run(path, network): +async def run(path, network): """Run Nile scripts with NileRuntimeEnvironment.""" - run_command(path, network) + await run_command(path, network) @cli.command() @@ -82,26 +82,30 @@ def run(path, network): @click.argument("arguments", nargs=-1) @network_option @click.option("--alias") -def deploy(artifact, arguments, network, alias): +@click.option("--salt") +@click.option("--token") +async def deploy(artifact, arguments, network, alias, salt, token): """Deploy StarkNet smart contract.""" - deploy_command(artifact, arguments, network, alias) + return await deploy_command(artifact, arguments, network, alias, salt, token) @cli.command() @click.argument("artifact", nargs=1) @network_option @click.option("--alias") -def declare(artifact, network, alias): +@click.option("--signature", nargs=2) +@click.option("--token") +async def declare(artifact, network, alias, signature, token): """Declare StarkNet smart contract.""" - declare_command(artifact, network, alias) + return await declare_command(artifact, network, alias, signature, token) @cli.command() @click.argument("signer", nargs=1) @network_option -def setup(signer, network): +async def setup(signer, network): """Set up an Account contract.""" - Account(signer, network) + await Account(signer, network) @cli.command() @@ -109,18 +113,23 @@ def setup(signer, network): @click.argument("contract_name", nargs=1) @click.argument("method", nargs=1) @click.argument("params", nargs=-1) +@click.option("--nonce", nargs=1) @click.option("--max_fee", nargs=1) @network_option -def send(signer, contract_name, method, params, network, max_fee=None): +async def send(signer, contract_name, method, params, network, nonce, max_fee): """Invoke a contract's method through an Account. Same usage as nile invoke.""" - account = Account(signer, network) + account = await Account(signer, network) print( "Calling {} on {} with params: {}".format( method, contract_name, [x for x in params] ) ) - out = account.send(contract_name, method, params, max_fee=max_fee) - print(out) + address, tx_hash = await account.send( + contract_name, method, params, nonce=nonce, max_fee=max_fee + ) + logging.info("Invoke transaction was sent.") + logging.info(f"Contract address: {address}") + logging.info(f"Transaction hash: {tx_hash}") @cli.command() @@ -129,12 +138,14 @@ def send(signer, contract_name, method, params, network, max_fee=None): @click.argument("params", nargs=-1) @click.option("--max_fee", nargs=1) @network_option -def invoke(contract_name, method, params, network, max_fee=None): +async def invoke(contract_name, method, params, network, max_fee=None): """Invoke functions of StarkNet smart contracts.""" - out = call_or_invoke_command( + address, tx_hash = await call_or_invoke_command( contract_name, "invoke", method, params, network, max_fee=max_fee ) - print(out) + logging.info("Invoke transaction was sent.") + logging.info(f"Contract address: {address}") + logging.info(f"Transaction hash: {tx_hash}") @cli.command() @@ -142,10 +153,12 @@ def invoke(contract_name, method, params, network, max_fee=None): @click.argument("method", nargs=1) @click.argument("params", nargs=-1) @network_option -def call(contract_name, method, params, network): +async def call(contract_name, method, params, network): """Call functions of StarkNet smart contracts.""" - out = call_or_invoke_command(contract_name, "call", method, params, network) - print(out) + result = await call_or_invoke_command( + contract_name, "call", method, params, network + ) + logging.info(result) @cli.command() @@ -243,4 +256,4 @@ def get_accounts(network): if __name__ == "__main__": - cli() + cli(_anyio_backend="asyncio") diff --git a/src/nile/common.py b/src/nile/common.py index be069288..a18fba1b 100644 --- a/src/nile/common.py +++ b/src/nile/common.py @@ -2,7 +2,12 @@ import json import os import re -import subprocess + +from starkware.starknet.cli.starknet_cli import NETWORKS, assert_tx_received +from starkware.starknet.services.api.feeder_gateway.feeder_gateway_client import ( + FeederGatewayClient, +) +from starkware.starknet.services.api.gateway.gateway_client import GatewayClient CONTRACTS_DIRECTORY = "contracts" BUILD_DIRECTORY = "artifacts" @@ -46,30 +51,41 @@ def get_all_contracts(ext=None, directory=None): return files -def run_command( - contract_name, network, overriding_path=None, operation="deploy", arguments=None -): - """Execute CLI command with given parameters.""" - base_path = ( - overriding_path if overriding_path else (BUILD_DIRECTORY, ABIS_DIRECTORY) - ) - contract = f"{base_path[0]}/{contract_name}.json" - command = ["starknet", operation, "--contract", contract] - - if arguments: - command.append("--inputs") - command.extend(prepare_params(arguments)) - - if network == "mainnet": - os.environ["STARKNET_NETWORK"] = "alpha-mainnet" - elif network == "goerli": - os.environ["STARKNET_NETWORK"] = "alpha-goerli" +async def get_gateway_response(network, tx, token): + """Execute transaction and return response.""" + gateway_url = get_gateway_url(network) + gateway_client = GatewayClient(url=gateway_url) + gateway_response = await gateway_client.add_transaction(tx=tx, token=token) + assert_tx_received(gateway_response) + + return gateway_response + + +async def get_feeder_response(network, tx): + """Execute transaction and return response.""" + gateway_url = get_feeder_url(network) + gateway_client = FeederGatewayClient(url=gateway_url) + gateway_response = await gateway_client.call_contract(invoke_tx=tx) + + return gateway_response["result"] + + +def get_gateway_url(network): + """Return gateway URL for specified network.""" + if network == "localhost": + return GATEWAYS.get(network) else: - command.append(f"--gateway_url={GATEWAYS.get(network)}") + network = "alpha-" + network + return f"https://{NETWORKS[network]}/gateway" - command.append("--no_wallet") - return subprocess.check_output(command) +def get_feeder_url(network): + """Return feeder gateway URL for specified network.""" + if network == "localhost": + return GATEWAYS.get(network) + else: + network = "alpha-" + network + return f"https://{NETWORKS[network]}/feeder_gateway" def parse_information(x): @@ -92,3 +108,9 @@ def prepare_params(params): if params is None: params = [] return stringify(params) + + +def prepare_return(x): + """Unpack list and convert hex to integer.""" + for y in x: + return int(y, 16) diff --git a/src/nile/core/account.py b/src/nile/core/account.py index 89fcd49f..0450ef48 100644 --- a/src/nile/core/account.py +++ b/src/nile/core/account.py @@ -16,10 +16,24 @@ load_dotenv() -class Account: +class AsyncObject(object): + """Base class for Account to allow async initialization.""" + + async def __new__(cls, *a, **kw): + """Return coroutine (not class so sync __init__ is not invoked).""" + instance = super().__new__(cls) + await instance.__init__(*a, **kw) + return instance + + async def __init__(self): + """Support Account async __init__.""" + pass + + +class Account(AsyncObject): """Account contract abstraction.""" - def __init__(self, signer, network): + async def __init__(self, signer, network): """Get or deploy an Account contract for the given private key.""" try: self.signer = Signer(int(os.environ[signer])) @@ -33,27 +47,27 @@ def __init__(self, signer, network): ) return - if accounts.exists(str(self.signer.public_key), network): - signer_data = next(accounts.load(str(self.signer.public_key), network)) + if accounts.exists(str(self.signer.public_key), self.network): + signer_data = next(accounts.load(str(self.signer.public_key), self.network)) self.address = signer_data["address"] self.index = signer_data["index"] else: - address, index = self.deploy() + address, index = await self.deploy() self.address = address self.index = index - def deploy(self): + async def deploy(self): """Deploy an Account contract for the given private key.""" index = accounts.current_index(self.network) pt = os.path.dirname(os.path.realpath(__file__)).replace("/core", "") overriding_path = (f"{pt}/artifacts", f"{pt}/artifacts/abis") - address, _ = deploy( - "Account", - [str(self.signer.public_key)], - self.network, - f"account-{index}", - overriding_path, + address, _ = await deploy( + contract_name="Account", + arguments=[str(self.signer.public_key)], + network=self.network, + alias=f"account-{index}", + overriding_path=overriding_path, ) accounts.register( @@ -62,15 +76,13 @@ def deploy(self): return address, index - def send(self, to, method, calldata, max_fee, nonce=None): + async def send(self, to, method, calldata, max_fee=None, nonce=None): """Execute a tx going through an Account contract.""" target_address, _ = next(deployments.load(to, self.network)) or to calldata = [int(x) for x in calldata] if nonce is None: - nonce = int( - call_or_invoke(self.address, "call", "get_nonce", [], self.network)[0] - ) + nonce = await self.get_nonce() if max_fee is None: max_fee = 0 @@ -89,12 +101,22 @@ def send(self, to, method, calldata, max_fee, nonce=None): params.extend(calldata) params.append(nonce) - return call_or_invoke( + return await call_or_invoke( contract=self.address, type="invoke", method="__execute__", params=params, network=self.network, - signature=[str(sig_r), str(sig_s)], - max_fee=str(max_fee), + signature=[sig_r, sig_s], + max_fee=max_fee, + ) + + async def get_nonce(self): + """Return nonce from account contract.""" + return await call_or_invoke( + contract=self.address, + type="call", + method="get_nonce", + params=[], + network=self.network, ) diff --git a/src/nile/core/call_or_invoke.py b/src/nile/core/call_or_invoke.py index 6dc10c10..4afc2e08 100644 --- a/src/nile/core/call_or_invoke.py +++ b/src/nile/core/call_or_invoke.py @@ -1,71 +1,37 @@ """Command to call or invoke StarkNet smart contracts.""" -import logging -import os -import subprocess + +from starkware.starknet.definitions import constants +from starkware.starknet.public.abi import get_selector_from_name +from starkware.starknet.services.api.gateway.transaction import InvokeFunction +from starkware.starknet.utils.api_utils import cast_to_felts from nile import deployments -from nile.common import GATEWAYS, prepare_params +from nile.common import get_feeder_response, get_gateway_response, prepare_return -def call_or_invoke( - contract, type, method, params, network, signature=None, max_fee=None +async def call_or_invoke( + contract, type, method, params, network, signature=None, max_fee=None, token=None ): """Call or invoke functions of StarkNet smart contracts.""" address, abi = next(deployments.load(contract, network)) - command = [ - "starknet", - type, - "--address", - address, - "--abi", - abi, - "--function", - method, - ] - - if network == "mainnet": - os.environ["STARKNET_NETWORK"] = "alpha-mainnet" - elif network == "goerli": - os.environ["STARKNET_NETWORK"] = "alpha-goerli" + if max_fee is None: + max_fee = 0 + + tx = InvokeFunction( + contract_address=int(address, 16), + entry_point_selector=get_selector_from_name(method), + calldata=cast_to_felts(params or []), + max_fee=int(max_fee), + signature=cast_to_felts(signature or []), + version=constants.TRANSACTION_VERSION, + ) + + if type == "call": + response = await get_feeder_response(network, tx) + return prepare_return(response) + elif type == "invoke": + response = await get_gateway_response(network, tx, token) + return address, response["transaction_hash"] else: - gateway_prefix = "feeder_gateway" if type == "call" else "gateway" - command.append(f"--{gateway_prefix}_url={GATEWAYS.get(network)}") - - params = prepare_params(params) - - if len(params) > 0: - command.append("--inputs") - command.extend(params) - - if signature is not None: - command.append("--signature") - command.extend(signature) - - if max_fee is not None: - command.append("--max_fee") - command.append(max_fee) - - command.append("--no_wallet") - - try: - return subprocess.check_output(command).strip().decode("utf-8") - except subprocess.CalledProcessError: - p = subprocess.Popen(command, stderr=subprocess.PIPE) - _, error = p.communicate() - err_msg = error.decode() - - if "max_fee must be bigger than 0" in err_msg: - logging.error( - """ - \n😰 Whoops, looks like max fee is missing. Try with:\n - --max_fee=`MAX_FEE` - """ - ) - elif "transactions should go through the __execute__ entrypoint." in err_msg: - logging.error( - "\n\n😰 Whoops, looks like you're not using an account. Try with:\n" - "\nnile send [OPTIONS] SIGNER CONTRACT_NAME METHOD [PARAMS]" - ) - - return "" + raise TypeError(f"Unknown type '{type}', must be 'call' or 'invoke'") diff --git a/src/nile/core/declare.py b/src/nile/core/declare.py index f6712c6c..707610cb 100644 --- a/src/nile/core/declare.py +++ b/src/nile/core/declare.py @@ -1,11 +1,25 @@ """Command to declare StarkNet smart contracts.""" import logging +from starkware.starknet.definitions import constants +from starkware.starknet.services.api.contract_class import ContractClass +from starkware.starknet.services.api.gateway.transaction import ( + DECLARE_SENDER_ADDRESS, + Declare, +) + from nile import deployments -from nile.common import DECLARATIONS_FILENAME, parse_information, run_command +from nile.common import ( + ABIS_DIRECTORY, + BUILD_DIRECTORY, + DECLARATIONS_FILENAME, + get_gateway_response, +) -def declare(contract_name, network, alias=None, overriding_path=None): +async def declare( + contract_name, network, alias=None, signature=None, overriding_path=None, token=None +): """Declare StarkNet smart contracts.""" logging.info(f"🚀 Declaring {contract_name}") @@ -13,12 +27,32 @@ def declare(contract_name, network, alias=None, overriding_path=None): file = f"{network}.{DECLARATIONS_FILENAME}" raise Exception(f"Alias {alias} already exists in {file}") - output = run_command(contract_name, network, overriding_path, operation="declare") - class_hash, tx_hash = parse_information(output) + base_path = ( + overriding_path if overriding_path else (BUILD_DIRECTORY, ABIS_DIRECTORY) + ) + artifact = f"{base_path[0]}/{contract_name}.json" + open_artifact = open(artifact, "r") + contract_class = ContractClass.loads(data=open_artifact.read()) + + if signature is None: + signature = [] + + tx = Declare( + contract_class=contract_class, + sender_address=DECLARE_SENDER_ADDRESS, + signature=signature, + nonce=0, + max_fee=0, + version=constants.TRANSACTION_VERSION, + ) + + response = await get_gateway_response(network=network, tx=tx, token=token) + class_hash, tx_hash = response["class_hash"], response["transaction_hash"] + + deployments.register_class_hash(class_hash, network, alias) logging.info(f"⏳ Declaration of {contract_name} successfully sent at {class_hash}") logging.info(f"🧾 Transaction hash: {tx_hash}") - deployments.register_class_hash(class_hash, network, alias) return class_hash diff --git a/src/nile/core/deploy.py b/src/nile/core/deploy.py index ce293668..85049fee 100644 --- a/src/nile/core/deploy.py +++ b/src/nile/core/deploy.py @@ -1,23 +1,50 @@ """Command to deploy StarkNet smart contracts.""" import logging +from starkware.starknet.cli.starknet_cli import get_salt +from starkware.starknet.definitions import constants +from starkware.starknet.services.api.contract_class import ContractClass +from starkware.starknet.services.api.gateway.transaction import Deploy +from starkware.starknet.utils.api_utils import cast_to_felts + from nile import deployments -from nile.common import ABIS_DIRECTORY, BUILD_DIRECTORY, parse_information, run_command +from nile.common import ABIS_DIRECTORY, BUILD_DIRECTORY, get_gateway_response -def deploy(contract_name, arguments, network, alias, overriding_path=None): +async def deploy( + contract_name, + arguments, + network, + alias, + overriding_path=None, + salt=None, + token=None, +): """Deploy StarkNet smart contracts.""" logging.info(f"🚀 Deploying {contract_name}") + + args = cast_to_felts(arguments or []) base_path = ( overriding_path if overriding_path else (BUILD_DIRECTORY, ABIS_DIRECTORY) ) + artifact = f"{base_path[0]}/{contract_name}.json" abi = f"{base_path[1]}/{contract_name}.json" - output = run_command(contract_name, network, overriding_path, arguments=arguments) + open_artifact = open(artifact, "r") + contract_class = ContractClass.loads(data=open_artifact.read()) + + tx = Deploy( + contract_address_salt=get_salt(salt), + contract_definition=contract_class, + constructor_calldata=args, + version=constants.TRANSACTION_VERSION, + ) - address, tx_hash = parse_information(output) + response = await get_gateway_response(network=network, tx=tx, token=token) + address, tx_hash = response["address"], response["transaction_hash"] + + deployments.register(address, abi, network, alias) logging.info(f"⏳ ️Deployment of {contract_name} successfully sent at {address}") logging.info(f"🧾 Transaction hash: {tx_hash}") - deployments.register(address, abi, network, alias) return address, abi diff --git a/src/nile/core/run.py b/src/nile/core/run.py index 431a878e..ec28efb7 100644 --- a/src/nile/core/run.py +++ b/src/nile/core/run.py @@ -5,10 +5,10 @@ from nile.nre import NileRuntimeEnvironment -def run(path, network): +async def run(path, network): """Run nile scripts passing on the NRE object.""" logger = logging.getLogger() logger.disabled = True script = SourceFileLoader("script", path).load_module() nre = NileRuntimeEnvironment(network) - script.run(nre) + await script.run(nre) diff --git a/src/nile/nre.py b/src/nile/nre.py index e9fc638b..ebf58240 100644 --- a/src/nile/nre.py +++ b/src/nile/nre.py @@ -32,9 +32,7 @@ def deploy(self, contract, arguments=None, alias=None, overriding_path=None): def call(self, contract, method, params=None): """Call a view function in a smart contract.""" - return str( - call_or_invoke(contract, "call", method, params, self.network) - ).split() + return call_or_invoke(contract, "call", method, params, self.network) def invoke(self, contract, method, params=None): """Invoke a mutable function in a smart contract.""" diff --git a/src/nile/utils/get_accounts.py b/src/nile/utils/get_accounts.py index adb4c2b5..f09ebaed 100644 --- a/src/nile/utils/get_accounts.py +++ b/src/nile/utils/get_accounts.py @@ -6,7 +6,7 @@ from nile.core.account import Account -def get_accounts(network): +async def get_accounts(network): """Retrieve deployed accounts.""" try: total_accounts = current_index(network) @@ -27,15 +27,15 @@ def get_accounts(network): for i in range(total_accounts): logging.info(f"{i}: {addresses[i]}") - _account = _check_and_return_account(signers[i], pubkeys[i], network) + _account = await _check_and_return_account(signers[i], pubkeys[i], network) accounts.append(_account) logging.info("\n🚀 Successfully retrieved deployed accounts") return accounts -def _check_and_return_account(signer, pubkey, network): - account = Account(signer, network) +async def _check_and_return_account(signer, pubkey, network): + account = await Account(signer, network) assert str(pubkey) == str( account.signer.public_key ), "Signer pubkey does not match deployed pubkey" diff --git a/tests/commands/test_account.py b/tests/commands/test_account.py index a7f41606..33a9c851 100644 --- a/tests/commands/test_account.py +++ b/tests/commands/test_account.py @@ -1,6 +1,6 @@ """Tests for account commands.""" import logging -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import pytest @@ -18,20 +18,30 @@ def tmp_working_dir(monkeypatch, tmp_path): return tmp_path -@patch("nile.core.account.Account.deploy") -def test_account_init(mock_deploy): - mock_deploy.return_value = MOCK_ADDRESS, MOCK_INDEX - account = Account(KEY, NETWORK) +class AsyncMock(Mock): + """Return asynchronous mock.""" - assert account.address == MOCK_ADDRESS - assert account.index == MOCK_INDEX - mock_deploy.assert_called_once() + async def __call__(self, *args, **kwargs): + """Return mocked coroutine.""" + return super(AsyncMock, self).__call__(*args, **kwargs) -def test_account_init_bad_key(caplog): +@pytest.mark.asyncio +async def test_account_init(): + with patch("nile.core.account.Account.deploy", new=AsyncMock()) as mock_deploy: + mock_deploy.return_value = MOCK_ADDRESS, MOCK_INDEX + account = await Account(KEY, NETWORK) + + assert account.address == MOCK_ADDRESS + assert account.index == MOCK_INDEX + mock_deploy.assert_called_once() + + +@pytest.mark.asyncio +async def test_account_init_bad_key(caplog): logging.getLogger().setLevel(logging.INFO) - Account("BAD_KEY", NETWORK) + await Account("BAD_KEY", NETWORK) assert ( "\n❌ Cannot find BAD_KEY in env." "\nCheck spelling and that it exists." @@ -39,10 +49,11 @@ def test_account_init_bad_key(caplog): ) in caplog.text -def test_account_multiple_inits_with_same_key(): - account = Account(KEY, NETWORK) - account.deploy() - account2 = Account(KEY, NETWORK) +@pytest.mark.asyncio +async def test_account_multiple_inits_with_same_key(): + account = await Account(KEY, NETWORK) + await account.deploy() + account2 = await Account(KEY, NETWORK) # Check addresses don't match assert account.address != account2.address @@ -51,70 +62,79 @@ def test_account_multiple_inits_with_same_key(): assert account2.index == 1 -@patch("nile.core.account.deploy", return_value=(1, 2)) -def test_deploy(mock_deploy): - account = Account(KEY, NETWORK) - with patch("nile.core.account.os.path.dirname") as mock_path: - test_path = "/overriding_path" - mock_path.return_value.replace.return_value = test_path - - account.deploy() - - mock_deploy.assert_called_with( - "Account", - [str(account.signer.public_key)], - NETWORK, - f"account-{account.index + 1}", - (f"{test_path}/artifacts", f"{test_path}/artifacts/abis"), +@pytest.mark.asyncio +async def test_deploy(): + account = await Account(KEY, NETWORK) + with patch("nile.core.account.deploy", new=AsyncMock()) as mock_deploy: + mock_deploy.return_value = (1, 2) + with patch("nile.core.account.os.path.dirname") as mock_path: + test_path = "/overriding_path" + mock_path.return_value.replace.return_value = test_path + + await account.deploy() + + mock_deploy.assert_called_with( + alias=f"account-{account.index + 1}", + arguments=[str(account.signer.public_key)], + contract_name="Account", + network=NETWORK, + overriding_path=( + f"{test_path}/artifacts", + f"{test_path}/artifacts/abis", + ), + ) + + +@pytest.mark.asyncio +async def test_deploy_accounts_register(): + with patch("nile.core.account.deploy", new=AsyncMock()) as mock_deploy: + with patch("nile.core.account.accounts.register") as mock_register: + mock_deploy.return_value = (MOCK_ADDRESS, MOCK_INDEX) + account = await Account(KEY, NETWORK) + + mock_register.assert_called_once_with( + account.signer.public_key, MOCK_ADDRESS, MOCK_INDEX, KEY, NETWORK + ) + + +@pytest.mark.asyncio +async def test_send_nonce_call(): + account = await Account(KEY, NETWORK) + contract_address, _ = await account.deploy() + with patch("nile.core.account.Account.get_nonce", new=AsyncMock()) as mock_nonce: + + # Instead of creating and populating a tmp .txt file, this uses the + # deployed account address (contract_address) as the target + mock_nonce.return_value = 1 + await account.send( + to=contract_address, method="method", calldata=[1, 2, 3], max_fee=1 ) - -@patch("nile.core.account.deploy", return_value=(MOCK_ADDRESS, MOCK_INDEX)) -@patch("nile.core.account.accounts.register") -def test_deploy_accounts_register(mock_register, mock_deploy): - account = Account(KEY, NETWORK) - - mock_register.assert_called_once_with( - account.signer.public_key, MOCK_ADDRESS, MOCK_INDEX, KEY, NETWORK - ) - - -@patch("nile.core.account.call_or_invoke") -def test_send_nonce_call(mock_call): - account = Account(KEY, NETWORK) - contract_address, _ = account.deploy() - - # Instead of creating and populating a tmp .txt file, this uses the - # deployed account address (contract_address) as the target - account.send(contract_address, "method", [1, 2, 3], max_fee=1) - - # 'call_or_invoke' is called twice ('get_nonce' and '__execute__') - assert mock_call.call_count == 2 - - # Check 'get_nonce' call - mock_call.assert_any_call(account.address, "call", "get_nonce", [], NETWORK) + # Check 'get_nonce' call + mock_nonce.assert_called_once() +@pytest.mark.asyncio @pytest.mark.parametrize( "callarray, calldata", # The following callarray and calldata args tests the Account's list comprehensions # ensuring they're set to strings and passed correctly - [([["111"]], []), ([["111", "222"]], ["333", "444", "555"])], + [([[111]], []), ([[111, 222]], [333, 444, 555])], ) -def test_send_sign_transaction_and_execute(callarray, calldata): - account = Account(KEY, NETWORK) - contract_address, _ = account.deploy() +async def test_send_sign_transaction_and_execute(callarray, calldata): + account = await Account(KEY, NETWORK) + contract_address, _ = await account.deploy() sig_r, sig_s = [999, 888] return_signature = [callarray, calldata, sig_r, sig_s] account.signer.sign_transaction = MagicMock(return_value=return_signature) - with patch("nile.core.account.call_or_invoke") as mock_call: + with patch("nile.core.account.call_or_invoke", new=AsyncMock()) as mock_call: send_args = [contract_address, "method", [1, 2, 3]] nonce = 4 max_fee = 1 - account.send(*send_args, max_fee, nonce) + await account.send(*send_args, max_fee, nonce) # Check values are correctly passed to 'sign_transaction' account.signer.sign_transaction.assert_called_once_with( @@ -124,16 +144,31 @@ def test_send_sign_transaction_and_execute(callarray, calldata): # Check values are correctly passed to '__execute__' mock_call.assert_called_with( contract=account.address, - max_fee=str(max_fee), + max_fee=max_fee, method="__execute__", network=NETWORK, params=[ len(callarray), - *(str(elem) for sublist in callarray for elem in sublist), + *(elem for sublist in callarray for elem in sublist), len(calldata), - *(str(param) for param in calldata), + *(param for param in calldata), nonce, ], - signature=[str(sig_r), str(sig_s)], + signature=[sig_r, sig_s], type="invoke", ) + + +@pytest.mark.asyncio +async def test_nonce_call(): + account = await Account(KEY, NETWORK) + with patch("nile.core.account.call_or_invoke", new=AsyncMock()) as mock_call: + await account.get_nonce() + + mock_call.assert_called_once_with( + contract=account.address, + type="call", + method="get_nonce", + params=[], + network=account.network, + ) diff --git a/tests/commands/test_declare.py b/tests/commands/test_declare.py index e8eb1726..a2987b9e 100644 --- a/tests/commands/test_declare.py +++ b/tests/commands/test_declare.py @@ -1,10 +1,15 @@ """Tests for declare command.""" import logging -from unittest.mock import patch +from unittest.mock import Mock, mock_open, patch import pytest +from starkware.starknet.definitions import constants +from starkware.starknet.services.api.gateway.transaction import ( + DECLARE_SENDER_ADDRESS, + Declare, +) -from nile.common import DECLARATIONS_FILENAME +from nile.common import ABIS_DIRECTORY, BUILD_DIRECTORY, DECLARATIONS_FILENAME from nile.core.declare import alias_exists, declare @@ -17,10 +22,19 @@ def tmp_working_dir(monkeypatch, tmp_path): CONTRACT = "contract" NETWORK = "goerli" ALIAS = "alias" -PATH = "path" -RUN_OUTPUT = b"output" +PATH2 = "artifacts2" +PATH_OVERRIDE = (PATH2, ABIS_DIRECTORY) HASH = 111 TX_HASH = 222 +RESPONSE = dict({"class_hash": HASH, "transaction_hash": TX_HASH}) + + +class AsyncMock(Mock): + """Return asynchronous mock.""" + + async def __call__(self, *args, **kwargs): + """Return mocked coroutine.""" + return super(AsyncMock, self).__call__(*args, **kwargs) def test_alias_exists(): @@ -33,54 +47,98 @@ def test_alias_exists(): assert alias_exists(ALIAS, NETWORK) is True +@pytest.mark.asyncio +async def test_declare(caplog): + logging.getLogger().setLevel(logging.INFO) + + with patch( + "nile.core.declare.get_gateway_response", new=AsyncMock() + ) as mock_response: + mock_response.return_value = RESPONSE + + with patch("nile.core.declare.open", new_callable=mock_open): + with patch("nile.core.declare.ContractClass") as mock_contract_class: + res = await declare(contract_name=CONTRACT, network=NETWORK) + assert res == HASH, TX_HASH + + # check passed args to response + mock_response.assert_called_once_with( + network=NETWORK, + tx=Declare( + version=constants.TRANSACTION_VERSION, + contract_class=mock_contract_class.loads(), + sender_address=DECLARE_SENDER_ADDRESS, + max_fee=0, + signature=[], + nonce=0, + ), + token=None, + ) + + # check logs + assert f"🚀 Declaring {CONTRACT}" in caplog.text + assert ( + f"📦 Registering {HASH} in {NETWORK}.declarations.txt" in caplog.text + ) + assert ( + f"⏳ Declaration of {CONTRACT} successfully sent at {HASH}" + in caplog.text + ) + assert f"🧾 Transaction hash: {TX_HASH}" in caplog.text + + +@pytest.mark.asyncio @pytest.mark.parametrize( - "args, exp_command, exp_register", + "args, exp_register", [ ( - [CONTRACT, NETWORK], # args - [CONTRACT, NETWORK, None], # expected command + {"contract_name": CONTRACT, "network": NETWORK}, # args [HASH, NETWORK, None], # expected register ), ( - [CONTRACT, NETWORK, ALIAS], # args - [CONTRACT, NETWORK, None], # expected command + {"contract_name": CONTRACT, "network": NETWORK, "alias": ALIAS}, # args [HASH, NETWORK, ALIAS], # expected register ), ( - [CONTRACT, NETWORK, ALIAS, PATH], # args - [CONTRACT, NETWORK, PATH], # expected command + { + "contract_name": CONTRACT, + "network": NETWORK, + "alias": ALIAS, + "overriding_path": PATH_OVERRIDE, + }, # args [HASH, NETWORK, ALIAS], # expected register ), ], ) -@patch("nile.core.declare.run_command", return_value=RUN_OUTPUT) -@patch("nile.core.declare.parse_information", return_value=[HASH, TX_HASH]) -@patch("nile.core.declare.deployments.register_class_hash") -def test_declare( - mock_register, mock_parse, mock_run_cmd, caplog, args, exp_command, exp_register -): - logging.getLogger().setLevel(logging.INFO) +async def test_declare_register(args, exp_register): + with patch( + "nile.core.declare.get_gateway_response", new=AsyncMock() + ) as mock_response: + mock_response.return_value = RESPONSE - # check return value - res = declare(*args) - assert res == HASH + with patch("nile.core.declare.open", new_callable=mock_open) as m_open: + with patch("nile.core.declare.ContractClass"): + with patch( + "nile.core.declare.deployments.register_class_hash" + ) as mock_register: - # check internals - mock_run_cmd.assert_called_once_with(*exp_command, operation="declare") - mock_parse.assert_called_once_with(RUN_OUTPUT) - mock_register.assert_called_once_with(*exp_register) + await declare(**args) - # check logs - assert f"🚀 Declaring {CONTRACT}" in caplog.text - assert f"⏳ Declaration of {CONTRACT} successfully sent at {HASH}" in caplog.text - assert f"🧾 Transaction hash: {TX_HASH}" in caplog.text + # check overriding path + base_path = ( + PATH2 if "overriding_path" in args.keys() else BUILD_DIRECTORY + ) + m_open.assert_called_once_with(f"{base_path}/{CONTRACT}.json", "r") + # check registration + mock_register.assert_called_once_with(*exp_register) -@patch("nile.core.declare.alias_exists", return_value=True) -def test_declare_duplicate_hash(mock_alias_check): +@pytest.mark.asyncio +@patch("nile.core.declare.alias_exists", return_value=True) +async def test_declare_duplicate_hash(mock_alias_check): with pytest.raises(Exception) as err: - declare(ALIAS, NETWORK) + await declare(ALIAS, NETWORK) assert ( f"Alias {ALIAS} already exists in {NETWORK}.{DECLARATIONS_FILENAME}" diff --git a/tests/commands/test_deploy.py b/tests/commands/test_deploy.py index 0e04bb9e..e34ffa70 100644 --- a/tests/commands/test_deploy.py +++ b/tests/commands/test_deploy.py @@ -1,8 +1,10 @@ """Tests for deploy command.""" import logging -from unittest.mock import patch +from unittest.mock import Mock, mock_open, patch import pytest +from starkware.starknet.definitions import constants +from starkware.starknet.services.api.gateway.transaction import Deploy from nile.core.deploy import ABIS_DIRECTORY, BUILD_DIRECTORY, deploy @@ -13,47 +15,105 @@ def tmp_working_dir(monkeypatch, tmp_path): return tmp_path +class AsyncMock(Mock): + """Return asynchronous mock.""" + + async def __call__(self, *args, **kwargs): + """Return mocked coroutine.""" + return super(AsyncMock, self).__call__(*args, **kwargs) + + CONTRACT = "contract" NETWORK = "goerli" ALIAS = "alias" ABI = f"{ABIS_DIRECTORY}/{CONTRACT}.json" -BASE_PATH = (BUILD_DIRECTORY, ABIS_DIRECTORY) -PATH_OVERRIDE = ("artifacts2", ABIS_DIRECTORY) -RUN_OUTPUT = b"output" +PATH2 = "artifacts2" +PATH_OVERRIDE = (PATH2, ABIS_DIRECTORY) ARGS = [1, 2, 3] ADDRESS = 999 TX_HASH = 222 +SALT = "0x123" +RESPONSE = dict({"address": ADDRESS, "transaction_hash": TX_HASH}) + + +@pytest.mark.asyncio +async def test_deploy(caplog): + logging.getLogger().setLevel(logging.INFO) + + with patch( + "nile.core.deploy.get_gateway_response", new=AsyncMock() + ) as mock_response: + mock_response.return_value = RESPONSE + with patch("nile.core.deploy.open", new_callable=mock_open): + with patch("nile.core.deploy.ContractClass") as mock_contract_class: + res = await deploy( + contract_name=CONTRACT, + arguments=ARGS, + network=NETWORK, + alias=ALIAS, + salt=SALT, + ) + + # check return values + assert res == (ADDRESS, ABI) + + # check response + mock_response.assert_called_once_with( + network=NETWORK, + tx=Deploy( + version=constants.TRANSACTION_VERSION, + contract_address_salt=int(SALT, 16), + contract_definition=mock_contract_class.loads(), + constructor_calldata=ARGS, + ), + token=None, + ) + + # check logs + assert f"🚀 Deploying {CONTRACT}" in caplog.text + assert f"⏳ ️Deployment of {CONTRACT} successfully sent at {ADDRESS}" in caplog.text + assert f"🧾 Transaction hash: {TX_HASH}" in caplog.text +@pytest.mark.asyncio @pytest.mark.parametrize( - "args, exp_command", + "args, exp_register", [ ( - [CONTRACT, ARGS, NETWORK, ALIAS], # args - [CONTRACT, NETWORK, None], # expected command + { + "contract_name": CONTRACT, + "arguments": ARGS, + "network": NETWORK, + "alias": None, + }, + [ADDRESS, ABI, NETWORK, None], # expected register ), ( - [CONTRACT, ARGS, NETWORK, ALIAS, PATH_OVERRIDE], # args - [CONTRACT, NETWORK, PATH_OVERRIDE], # expected command + { + "contract_name": CONTRACT, + "arguments": ARGS, + "network": NETWORK, + "alias": ALIAS, + "overriding_path": PATH_OVERRIDE, + }, + [ADDRESS, ABI, NETWORK, ALIAS], # expected register ), ], ) -@patch("nile.core.deploy.run_command", return_value=RUN_OUTPUT) -@patch("nile.core.deploy.parse_information", return_value=[ADDRESS, TX_HASH]) -@patch("nile.core.deploy.deployments.register") -def test_deploy(mock_register, mock_parse, mock_run_cmd, caplog, args, exp_command): - logging.getLogger().setLevel(logging.INFO) +async def test_deploy_registration(args, exp_register): + with patch( + "nile.core.deploy.get_gateway_response", new=AsyncMock() + ) as mock_response: + mock_response.return_value = RESPONSE + with patch("nile.core.deploy.open", new_callable=mock_open) as m_open: + with patch("nile.core.deploy.ContractClass"): + with patch("nile.core.deploy.deployments.register") as mock_register: - # check return values - res = deploy(*args) - assert res == (ADDRESS, ABI) + await deploy(**args) - # check internals - mock_run_cmd.assert_called_once_with(*exp_command, arguments=ARGS) - mock_parse.assert_called_once_with(RUN_OUTPUT) - mock_register.assert_called_once_with(ADDRESS, ABI, NETWORK, ALIAS) - - # check logs - assert f"🚀 Deploying {CONTRACT}" in caplog.text - assert f"⏳ ️Deployment of {CONTRACT} successfully sent at {ADDRESS}" in caplog.text - assert f"🧾 Transaction hash: {TX_HASH}" in caplog.text + # check overriding path + base_path = ( + PATH2 if "overriding_path" in args.keys() else BUILD_DIRECTORY + ) + m_open.assert_called_once_with(f"{base_path}/{CONTRACT}.json", "r") + mock_register.assert_called_once_with(*exp_register) diff --git a/tests/commands/test_get_accounts.py b/tests/commands/test_get_accounts.py index 0205f844..832359e8 100644 --- a/tests/commands/test_get_accounts.py +++ b/tests/commands/test_get_accounts.py @@ -1,6 +1,6 @@ """Tests for get-accounts command.""" import logging -from unittest.mock import MagicMock, patch +from unittest.mock import Mock, patch import pytest @@ -30,6 +30,14 @@ } +class AsyncMock(Mock): + """Return asynchronous mock.""" + + async def __call__(self, *args, **kwargs): + """Return mocked coroutine.""" + return super(AsyncMock, self).__call__(*args, **kwargs) + + @pytest.fixture(autouse=True) def tmp_working_dir(monkeypatch, tmp_path): monkeypatch.chdir(tmp_path) @@ -42,6 +50,7 @@ def mock_subprocess(): yield mock_subprocess +@pytest.mark.asyncio @pytest.mark.parametrize( "private_keys, public_keys", [ @@ -49,13 +58,14 @@ def mock_subprocess(): ([ALIASES[1], PUBKEYS[1]]), ], ) -def test__check_and_return_account_with_matching_keys(private_keys, public_keys): +async def test__check_and_return_account_with_matching_keys(private_keys, public_keys): # Check matching public/private keys - account = _check_and_return_account(private_keys, public_keys, NETWORK) + account = await _check_and_return_account(private_keys, public_keys, NETWORK) assert type(account) is Account +@pytest.mark.asyncio @pytest.mark.parametrize( "private_keys, public_keys", [ @@ -63,16 +73,19 @@ def test__check_and_return_account_with_matching_keys(private_keys, public_keys) ([ALIASES[1], PUBKEYS[0]]), ], ) -def test__check_and_return_account_with_mismatching_keys(private_keys, public_keys): +async def test__check_and_return_account_with_mismatching_keys( + private_keys, public_keys +): # Check mismatched public/private keys with pytest.raises(AssertionError) as err: - _check_and_return_account(private_keys, public_keys, NETWORK) + await _check_and_return_account(private_keys, public_keys, NETWORK) assert "Signer pubkey does not match deployed pubkey" in str(err.value) -def test_get_accounts_no_activated_accounts_feedback(capsys): - get_accounts(NETWORK) +@pytest.mark.asyncio +async def test_get_accounts_no_activated_accounts_feedback(capsys): + await get_accounts(NETWORK) # This test uses capsys in order to test the print statements (instead of logging) captured = capsys.readouterr() @@ -85,42 +98,46 @@ def test_get_accounts_no_activated_accounts_feedback(capsys): ) -@patch("nile.utils.get_accounts.current_index", MagicMock(return_value=len(PUBKEYS))) -@patch("nile.utils.get_accounts.open", MagicMock()) -@patch("nile.utils.get_accounts.json.load", MagicMock(return_value=MOCK_ACCOUNTS)) -def test_get_accounts_activated_accounts_feedback(caplog): +@pytest.mark.asyncio +async def test_get_accounts_activated_accounts_feedback(caplog): logging.getLogger().setLevel(logging.INFO) - # Default argument - get_accounts(NETWORK) + with patch("nile.utils.get_accounts.current_index", return_value=len(PUBKEYS)): + with patch("nile.utils.get_accounts.open"): + with patch("nile.utils.get_accounts.json.load", return_value=MOCK_ACCOUNTS): - # Check total accounts log - assert f"\nTotal registered accounts: {len(PUBKEYS)}\n" in caplog.text + # Default argument + await get_accounts(NETWORK) - # Check index/address log - for i in range(len(PUBKEYS)): - assert f"{INDEXES[i]}: {ADDRESSES[i]}" in caplog.text + # Check total accounts log + assert f"\nTotal registered accounts: {len(PUBKEYS)}\n" in caplog.text - # Check final success log - assert "\n🚀 Successfully retrieved deployed accounts" in caplog.text + # Check index/address log + for i in range(len(PUBKEYS)): + assert f"{INDEXES[i]}: {ADDRESSES[i]}" in caplog.text + # Check final success log + assert "\n🚀 Successfully retrieved deployed accounts" in caplog.text -@patch("nile.utils.get_accounts.current_index", MagicMock(return_value=len(PUBKEYS))) -@patch("nile.utils.get_accounts.open", MagicMock()) -@patch("nile.utils.get_accounts.json.load", MagicMock(return_value=MOCK_ACCOUNTS)) -def test_get_accounts_with_keys(): - with patch( - "nile.utils.get_accounts._check_and_return_account" - ) as mock_return_account: - result = get_accounts(NETWORK) +@pytest.mark.asyncio +async def test_get_accounts_with_keys(): + with patch("nile.utils.get_accounts.current_index", return_value=len(PUBKEYS)): + with patch("nile.utils.get_accounts.open"): + with patch("nile.utils.get_accounts.json.load", return_value=MOCK_ACCOUNTS): + with patch( + "nile.utils.get_accounts._check_and_return_account", new=AsyncMock() + ) as mock_return_account: + result = await get_accounts(NETWORK) - # Check correct args are passed to `_check_and_receive_account` - for i in range(len(PUBKEYS)): - mock_return_account.assert_any_call(ALIASES[i], PUBKEYS[i], NETWORK) + # Check correct args are passed to `_check_and_receive_account` + for i in range(len(PUBKEYS)): + mock_return_account.assert_any_call( + ALIASES[i], PUBKEYS[i], NETWORK + ) - # Assert call count equals correct number of accounts - assert mock_return_account.call_count == len(PUBKEYS) + # Assert call count equals correct number of accounts + assert mock_return_account.call_count == len(PUBKEYS) - # assert returned accounts array equals correct number of accounts - assert len(result) == len(PUBKEYS) + # assert returned accounts array equals correct number of accounts + assert len(result) == len(PUBKEYS) diff --git a/tests/conftest.py b/tests/conftest.py index b86be8d6..15c81542 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1 +1,9 @@ """Test configuration for pytest.""" +import asyncio + +import pytest + + +@pytest.fixture(scope="module") +def event_loop(): + return asyncio.new_event_loop() diff --git a/tests/test_cli.py b/tests/test_cli.py index e7b8a3f0..435e719a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -9,12 +9,12 @@ from signal import SIGINT from threading import Timer from time import sleep -from unittest.mock import patch +from unittest.mock import Mock, patch from urllib.error import URLError from urllib.request import urlopen import pytest -from click.testing import CliRunner +from asyncclick.testing import CliRunner from nile.cli import cli from nile.common import ( @@ -31,6 +31,14 @@ pytestmark = pytest.mark.end_to_end +class AsyncMock(Mock): + """Return asynchronous mock.""" + + async def __call__(self, *args, **kwargs): + """Return mocked coroutine.""" + return super(AsyncMock, self).__call__(*args, **kwargs) + + @pytest.fixture(autouse=True) def tmp_working_dir(monkeypatch, tmp_path): monkeypatch.chdir(tmp_path) @@ -43,11 +51,12 @@ def create_process(target, args): return p -def start_node(seconds, node_args): +@pytest.mark.asyncio +async def start_node(seconds, node_args): """Start node with host and port specified in node_args and life in seconds.""" # Timed kill command with SIGINT to close Node process Timer(seconds, lambda: kill(getpid(), SIGINT)).start() - CliRunner().invoke(cli, ["node", *node_args]) + await CliRunner().invoke(cli, ["node", *node_args]) def check_node(p, seconds, gateway_url): @@ -63,13 +72,15 @@ def check_node(p, seconds, gateway_url): continue -def test_clean(): +@pytest.mark.asyncio +async def test_clean(): # The implementation is already thoroughly covered by unit tests, so here # we just check whether the command completes successfully. - result = CliRunner().invoke(cli, ["clean"]) + result = await CliRunner().invoke(cli, ["clean"]) assert result.exit_code == 0 +@pytest.mark.asyncio @pytest.mark.parametrize( "args, expected", [ @@ -83,7 +94,7 @@ def test_clean(): reason="Issue in cairo-lang. " "See https://github.com/starkware-libs/cairo-lang/issues/27", ) -def test_compile(args, expected): +async def test_compile(args, expected): contract_source = RESOURCES_DIR / "contracts" / "contract.cairo" target_dir = Path(CONTRACTS_DIRECTORY) @@ -98,34 +109,35 @@ def test_compile(args, expected): assert not abi_dir.exists() assert not build_dir.exists() - result = CliRunner().invoke(cli, ["compile", *args]) + result = await CliRunner().invoke(cli, ["compile", *args]) assert result.exit_code == 0 assert {f.name for f in abi_dir.glob("*.json")} == expected assert {f.name for f in build_dir.glob("*.json")} == expected +@pytest.mark.asyncio @pytest.mark.xfail( sys.version_info >= (3, 10), reason="Issue in cairo-lang. " "See https://github.com/starkware-libs/cairo-lang/issues/27", ) -@patch("nile.core.node.subprocess") -def test_node_forwards_args(mock_subprocess): - args = [ - "--host", - "localhost", - "--port", - "5001", - "--seed", - "1234", - ] - - result = CliRunner().invoke(cli, ["node", *args]) - assert result.exit_code == 0 +async def test_node_forwards_args(): + with patch("nile.core.node.subprocess") as mock_subprocess: + args = [ + "--host", + "localhost", + "--port", + "5001", + "--seed", + "1234", + ] + + result = await CliRunner().invoke(cli, ["node", *args]) + assert result.exit_code == 0 - expected = ["starknet-devnet", *args] - mock_subprocess.check_call.assert_called_once_with(expected) + expected = ["starknet-devnet", *args] + mock_subprocess.check_call.assert_called_once_with(expected) @pytest.mark.parametrize( @@ -141,7 +153,7 @@ def test_node_forwards_args(mock_subprocess): reason="Issue in cairo-lang. " "See https://github.com/starkware-libs/cairo-lang/issues/27", ) -def test_node_runs_gateway(opts, expected): +async def test_node_runs_gateway(opts, expected): # Node life seconds = 15 @@ -162,7 +174,8 @@ def test_node_runs_gateway(opts, expected): # Spawn process to start StarkNet local network with specified port # i.e. $ nile node --host localhost --port 5001 - p = create_process(target=start_node, args=(seconds, args)) + init_node = await start_node(seconds, args) + p = create_process(target=init_node, args=(seconds, args)) p.start() # Check node heartbeat and assert that it is running @@ -177,6 +190,7 @@ def test_node_runs_gateway(opts, expected): assert gateway.get(network) == expected +@pytest.mark.asyncio @pytest.mark.parametrize( "args", [ @@ -184,17 +198,19 @@ def test_node_runs_gateway(opts, expected): ([MOCK_HASH, "--network", "mainnet", "--contracts_file", "example.txt"]), ], ) -@patch("nile.utils.debug.subprocess") -def test_debug(mock_subprocess, args): - # debug will hang without patch - mock_subprocess.check_output.return_value = json.dumps({"tx_status": "ACCEPTED"}) +async def test_debug(args): + with patch("nile.utils.debug.subprocess") as mock_subprocess: + # debug will hang without patch + mock_subprocess.check_output.return_value = json.dumps( + {"tx_status": "ACCEPTED"} + ) - result = CliRunner().invoke(cli, ["debug", *args]) + result = await CliRunner().invoke(cli, ["debug", *args]) - # Check status - assert result.exit_code == 0 + # Check status + assert result.exit_code == 0 - # Setup and assert expected output - expected = ["starknet", "tx_status", "--hash", MOCK_HASH] + # Setup and assert expected output + expected = ["starknet", "tx_status", "--hash", MOCK_HASH] - mock_subprocess.check_output.assert_called_once_with(expected) + mock_subprocess.check_output.assert_called_once_with(expected) diff --git a/tests/test_common.py b/tests/test_common.py index f8944ded..365c290b 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,9 +1,15 @@ """Tests for common library.""" -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest +from starkware.starkware_utils.error_handling import StarkErrorCode -from nile.common import BUILD_DIRECTORY, prepare_params, run_command, stringify +from nile.common import ( + get_feeder_response, + get_gateway_response, + prepare_params, + stringify, +) CONTRACT = "contract" OPERATION = "invoke" @@ -14,25 +20,62 @@ LIST3 = [1, 2, 3, [4, 5, 6, [7, 8, 9]]] -@pytest.mark.parametrize("operation", ["invoke", "call"]) -@patch("nile.common.subprocess.check_output") -def test_run_command(mock_subprocess, operation): - - run_command( - contract_name=CONTRACT, network=NETWORK, operation=operation, arguments=ARGS - ) - - mock_subprocess.assert_called_once_with( - [ - "starknet", - operation, - "--contract", - f"{BUILD_DIRECTORY}/{CONTRACT}.json", - "--inputs", - *ARGS, - "--no_wallet", - ] - ) +TX_RECEIVED = dict({"code": StarkErrorCode.TRANSACTION_RECEIVED.name, "result": "test"}) +TX_FAILED = dict({"code": "test"}) + + +class AsyncMock(Mock): + """Return asynchronous mock.""" + + async def __call__(self, *args, **kwargs): + """Return mocked coroutine.""" + return super(AsyncMock, self).__call__(*args, **kwargs) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "success, tx_response", + [ + (False, TX_FAILED), + (True, TX_RECEIVED), + ], +) +async def test_get_gateway_response(success, tx_response): + with patch("nile.core.call_or_invoke.InvokeFunction") as mock_tx: + with patch( + "nile.common.GatewayClient.add_transaction", new=AsyncMock() + ) as mock_client: + mock_client.return_value = tx_response + args = dict({"network": NETWORK, "tx": mock_tx, "token": None}) + + if success: + # success + res = await get_gateway_response(**args) + assert res == tx_response + + else: + mock_client.return_value = tx_response + + with pytest.raises(BaseException) as err: + await get_gateway_response(**args) + assert "Failed to send transaction. Response: {'code': 'test'}." in str( + err.value + ) + + mock_client.assert_called_once_with(tx=mock_tx, token=None) + + +@pytest.mark.asyncio +async def test_get_feeder_response(): + with patch("nile.core.call_or_invoke.InvokeFunction") as mock_tx: + with patch( + "nile.common.FeederGatewayClient.call_contract", new=AsyncMock() + ) as mock_client: + mock_client.return_value = TX_RECEIVED + args = dict({"network": NETWORK, "tx": mock_tx}) + + res = await get_feeder_response(**args) + assert res == TX_RECEIVED["result"] @pytest.mark.parametrize( diff --git a/tox.ini b/tox.ini index b6f4f7d3..7c5cc5bf 100644 --- a/tox.ini +++ b/tox.ini @@ -14,6 +14,7 @@ passenv = extras = testing deps = + asyncclick cairo-lang==0.9.1 starknet-devnet # See https://github.com/starkware-libs/cairo-lang/issues/52