Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for enable_decimals for Vyper 0.4 #133

Merged
merged 2 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,15 @@ Import the voting contract types like this:
import voting.ballot as ballot
```

### Decimals

To use decimals on Vyper 0.4, use the following config:

```yaml
vyper:
enable_decimals: true
```

### Pragmas

Ape-Vyper supports Vyper 0.3.10's [new pragma formats](https://github.com/vyperlang/vyper/pull/3493)
Expand Down
7 changes: 5 additions & 2 deletions tests/ape-config.yaml → ape-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# Allows compiling to work from the project-level.
contracts_folder: contracts/passing_contracts
contracts_folder: tests/contracts/passing_contracts

# Specify a dependency to use in Vyper imports.
dependencies:
- name: exampledependency
local: ./ExampleDependency
local: ./tests/ExampleDependency

# NOTE: Snekmate does not need to be listed here since
# it is installed in site-packages. However, we include it
# to show it doesn't cause problems when included.
- python: snekmate
config_override:
contracts_folder: .

vyper:
enable_decimals: true
3 changes: 1 addition & 2 deletions ape_vyper/compiler/_versions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ def get_settings(
optimization = False

selection_dict = self._get_selection_dictionary(selection, project=pm)
search_paths = [*getsitepackages()]
search_paths.append(".")
search_paths = [*getsitepackages(), "."]

version_settings[settings_key] = {
"optimize": optimization,
Expand Down
19 changes: 19 additions & 0 deletions ape_vyper/compiler/_versions/vyper_04.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,25 @@ def get_import_remapping(self, project: Optional[ProjectManager] = None) -> dict
# You always import via module or package name.
return {}

def get_settings(
self,
version: Version,
source_paths: Iterable[Path],
compiler_data: dict,
project: Optional[ProjectManager] = None,
) -> dict:
pm = project or self.local_project

enable_decimals = self.api.get_config(project=pm).enable_decimals
if enable_decimals is None:
enable_decimals = False

settings = super().get_settings(version, source_paths, compiler_data, project=pm)
for settings_set in settings.values():
settings_set["enable_decimals"] = enable_decimals

return settings

def _get_sources_dictionary(
self, source_ids: Iterable[str], project: Optional[ProjectManager] = None, **kwargs
) -> dict[str, dict]:
Expand Down
16 changes: 12 additions & 4 deletions ape_vyper/compiler/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,17 @@ def compile(
settings: Optional[dict] = None,
) -> Iterator[ContractType]:
pm = project or self.local_project

original_settings = self.compiler_settings
self.compiler_settings = {**self.compiler_settings, **(settings or {})}
try:
yield from self._compile(contract_filepaths, project=pm)
finally:
self.compiler_settings = original_settings

def _compile(
self, contract_filepaths: Iterable[Path], project: Optional[ProjectManager] = None
):
pm = project or self.local_project
contract_types: list[ContractType] = []
import_map = self._import_resolver.get_imports(pm, contract_filepaths)
config = self.get_config(pm)
Expand Down Expand Up @@ -514,12 +523,11 @@ def init_coverage_profile(
def enrich_error(self, err: ContractLogicError) -> ContractLogicError:
return enrich_error(err)

# TODO: In 0.9, make sure project is a kwarg here.
def trace_source(
self, contract_source: ContractSource, trace: TraceAPI, calldata: HexBytes
) -> SourceTraceback:
frames = trace.get_raw_frames()
tracer = SourceTracer(contract_source, frames, calldata)
return tracer.trace()
return SourceTracer.trace(trace.get_raw_frames(), contract_source, calldata)

def _get_compiler_arguments(
self,
Expand Down
7 changes: 7 additions & 0 deletions ape_vyper/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ class VyperConfig(PluginConfig):

"""

enable_decimals: Optional[bool] = None
"""
On Vyper 0.4, to use decimal types, you must enable it.
Defaults to ``None`` to avoid misleading that ``False``
means you cannot use decimals on a lower version.
"""

@field_validator("version", mode="before")
def validate_version(cls, value):
return pragma_str_to_specifier_set(value) if isinstance(value, str) else value
Expand Down
7 changes: 5 additions & 2 deletions ape_vyper/flattener.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ape.logging import logger
from ape.managers import ProjectManager
from ape.utils import ManagerAccessMixin
from ape.utils import ManagerAccessMixin, get_relative_path
from ethpm_types.source import Content

from ape_vyper._utils import get_version_pragma_spec
Expand Down Expand Up @@ -65,7 +65,10 @@ def _flatten_source(
flattened_modules = ""
modules_prefixes: set[str] = set()

for import_path in sorted(imports):
# Source by source ID for greater consistency..
for import_path in sorted(
imports, key=lambda p: f"{get_relative_path(p.absolute(), pm.path)}"
):
import_info = imports[import_path]

# Vyper imported interface names come from their file names
Expand Down
26 changes: 6 additions & 20 deletions ape_vyper/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def __init__(self, project: ProjectManager, paths: list[Path]):
# Even though we build up mappings of all sources, as may be referenced
# later on and that prevents re-calculating over again, we only
# "show" the items requested.
self._request_view: list[Path] = paths
self.paths: list[Path] = paths

def __getitem__(self, item: Union[str, Path], *args, **kwargs) -> list[Import]:
if isinstance(item, str) or not item.is_absolute():
Expand Down Expand Up @@ -294,7 +294,7 @@ def keys(self) -> list[Path]: # type: ignore
result = []
keys = sorted(list(super().keys()))
for path in keys:
if path not in self._request_view:
if path not in self.paths:
continue

result.append(path)
Expand All @@ -311,7 +311,7 @@ def values(self) -> list[list[Import]]: # type: ignore
def items(self) -> list[tuple[Path, list[Import]]]: # type: ignore
result = []
for path in self.keys(): # sorted
if path not in self._request_view:
if path not in self.paths:
continue

result.append((path, self[path]))
Expand All @@ -328,30 +328,16 @@ class ImportResolver(ManagerAccessMixin):
_projects: dict[str, ImportMap] = {}
_dependency_attempted_compile: set[str] = set()

def get_imports(
self,
project: ProjectManager,
contract_filepaths: Iterable[Path],
) -> ImportMap:
def get_imports(self, project: ProjectManager, contract_filepaths: Iterable[Path]) -> ImportMap:
paths = list(contract_filepaths)
reset_view = None
if project.project_id not in self._projects:
self._projects[project.project_id] = ImportMap(project, paths)
else:
# Change the items we "view". Some (or all) may need to be added as well.
reset_view = self._projects[project.project_id]._request_view
self._projects[project.project_id]._request_view = paths

try:
import_map = self._get_imports(paths, project)
finally:
if reset_view is not None:
self._projects[project.project_id]._request_view = reset_view

return import_map
return self._get_imports(paths, project)

def _get_imports(self, paths: list[Path], project: ProjectManager) -> ImportMap:
import_map = self._projects[project.project_id]
import_map.paths = list({*import_map.paths, *paths})
for path in paths:
if path in import_map:
# Already handled.
Expand Down
Loading
Loading