Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update gt4py on 2022.11.07 #375

Merged
merged 60 commits into from
Dec 21, 2022
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
a283712
Update gt4py on 2022.11.07
jdahm Nov 7, 2022
c8a3236
Regenerate constraints.txt
jdahm Nov 8, 2022
0902b3a
Fixes
jdahm Nov 10, 2022
a4cb91e
Fix type hints
jdahm Nov 10, 2022
a01ccaf
Fix in thresholds.py
jdahm Nov 10, 2022
a93e64a
_interpolate_origin -> _translate_origin
jdahm Nov 10, 2022
cf0e901
Fix typing
jdahm Nov 10, 2022
baf9b8f
Fix spelling
jdahm Nov 10, 2022
8f9a48f
Finish fixing pace.util tests
jdahm Nov 10, 2022
ed7f4fe
Add from_array to StorageNumpy
jdahm Nov 10, 2022
2eddc20
Fix __descriptor__?
jdahm Nov 10, 2022
e6969c4
Merge branch 'main' into update-gt4py
jdahm Nov 10, 2022
374e3a4
Change split_cartesian_into_storages
jdahm Nov 10, 2022
884a8a7
Fix storage references in translate tests
jdahm Nov 10, 2022
b92b37c
Fix data handling in translate tests
jdahm Nov 10, 2022
f0bf465
Check for both numpy and cupy ndarray types in physics translate code
jdahm Nov 10, 2022
b29d1e1
Fix flake8 error
jdahm Nov 10, 2022
22b71b8
Do not raise exception on cupy not installed
jdahm Nov 10, 2022
a23ad48
Put back conditional
jdahm Nov 11, 2022
a1c554f
Handle dict input in get_data
jdahm Nov 11, 2022
4c54f92
Address feedback
jdahm Nov 14, 2022
476a733
Check for optimal layout
jdahm Nov 14, 2022
d7d1494
Reset from_array
jdahm Nov 14, 2022
f437d88
Check arg.dims for Quantity
jdahm Nov 14, 2022
fc8417f
Fix padding issue in translate tests
jdahm Nov 14, 2022
7fc291b
Switch to fixing in place
jdahm Nov 15, 2022
a26e354
Try returning a dace descriptor
jdahm Nov 17, 2022
bb2618a
Return typeclass
jdahm Nov 17, 2022
9f5c42a
Lint
jdahm Nov 17, 2022
fba75a8
Add storage property to __descriptor__
jdahm Nov 18, 2022
af4edd0
Use dace.data.create_datadescriptor
jdahm Nov 21, 2022
fc8979c
Add return
jdahm Nov 26, 2022
5d47cec
Merge branch 'main' into update-gt4py
jdahm Dec 1, 2022
69a4643
Update gt4py
jdahm Dec 2, 2022
247ef30
Use gt4py.storage.dace_descriptor
jdahm Dec 2, 2022
921e4ae
Update attrs
jdahm Dec 2, 2022
cdc5990
Pass backend to quantity in QuantityFactory (when available)
FlorianDeconinck Dec 6, 2022
d7349c9
lint
FlorianDeconinck Dec 6, 2022
a0a9880
Don't use hasattr
FlorianDeconinck Dec 6, 2022
907b0e1
Try using dace descriptor since we use `cupy` objects
FlorianDeconinck Dec 8, 2022
17bd825
Cleanup of the quantity descriptor
FlorianDeconinck Dec 8, 2022
f08be00
lint
FlorianDeconinck Dec 8, 2022
e265bc9
Add back dace optional dependency
jdahm Dec 8, 2022
733fde8
Merge branch 'main' into update-gt4py
FlorianDeconinck Dec 12, 2022
e5b076e
Move to go31 project except buildenv
jdahm Dec 13, 2022
44b996f
Set a few back
jdahm Dec 13, 2022
8596aae
Point to correct buildenv version
jdahm Dec 13, 2022
152563f
Remove unused _set_device_modified, deleted in new GT4Py
FlorianDeconinck Dec 14, 2022
4495701
Merge branch 'move-ci-project' into update-gt4py-go31
jdahm Dec 14, 2022
7fbd87b
Merge branch 'update-gt4py' into update-gt4py-go31
jdahm Dec 14, 2022
523be60
Add `skip_test` option
FlorianDeconinck Dec 19, 2022
36f1da4
Merge branch 'move-ci-project' into update-gt4py-go31
FlorianDeconinck Dec 20, 2022
7902dd4
Correct no storage issues in test harness
FlorianDeconinck Dec 20, 2022
cab807d
Fix device_sync in translate test
FlorianDeconinck Dec 20, 2022
4b2fbc7
Fix data call in physics parallel test
FlorianDeconinck Dec 21, 2022
8fdc476
Merge branch 'update-gt4py-go31' into update-gt4py
FlorianDeconinck Dec 21, 2022
ee3549b
Transparent device-copy for physics parallel test
FlorianDeconinck Dec 21, 2022
b4e8489
Fillz test transparent device copy
FlorianDeconinck Dec 21, 2022
e26d50a
typo
FlorianDeconinck Dec 21, 2022
237e819
Actuall fillz transalte fix
FlorianDeconinck Dec 21, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://gitlab.com/pycqa/flake8
- repo: https://github.com/pycqa/flake8
rev: 3.9.2
hooks:
- id: flake8
Expand Down
3 changes: 2 additions & 1 deletion constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ dace==0.14
# -r requirements_dev.txt
# pace-dsl
# pace-dsl (dsl/setup.py)
# pace-util
dacite==1.6.0
# via
# fv3config
Expand Down Expand Up @@ -541,7 +542,7 @@ traitlets==5.5.0
# nbformat
typed-ast==1.4.3
# via mypy
typing-extensions==3.10.0.0
typing-extensions==4.3.0
# via
# aiohttp
# black
Expand Down
6 changes: 3 additions & 3 deletions driver/pace/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,9 +595,9 @@ def _critical_path_step_all(
self.end_of_step_update(
dycore_state=self.state.dycore_state,
phy_state=self.state.physics_state,
u_dt=self.state.tendency_state.u_dt.storage,
v_dt=self.state.tendency_state.v_dt.storage,
pt_dt=self.state.tendency_state.pt_dt.storage,
u_dt=self.state.tendency_state.u_dt,
v_dt=self.state.tendency_state.v_dt,
pt_dt=self.state.tendency_state.pt_dt,
dt=float(dt),
)
self._end_of_step_actions(step)
Expand Down
36 changes: 24 additions & 12 deletions dsl/pace/dsl/dace/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,26 @@
from pace.util.mpi import MPI


try:
import cupy as cp
except ImportError:
cp = None


def dace_inhibitor(func: Callable):
"""Triggers callback generation wrapping `func` while doing DaCe parsing."""
return func


def _upload_to_device(host_data: List[Any]):
"""Make sure any data that are still a gt4py.storage gets uploaded to device"""
for data in host_data:
if isinstance(data, gt4py.storage.Storage):
data.host_to_device()
"""Make sure any ndarrays gets uploaded to the device

This will raise an assertion if cupy is not installed.
"""
assert cp is not None
for i, data in enumerate(host_data):
if isinstance(data, cp.ndarray):
host_data[i] = cp.asarray(data)


def _download_results_from_dace(
Expand All @@ -55,10 +65,11 @@ def _download_results_from_dace(
gt4py_results = None
if dace_result is not None:
for arg in args:
if isinstance(arg, gt4py.storage.Storage) and hasattr(
arg, "_set_device_modified"
):
arg._set_device_modified()
try:
if isinstance(arg, cp.ndarray):
arg._set_device_modified()
except AttributeError:
pass
if config.is_gpu_backend():
gt4py_results = [
gt4py.storage.from_array(
Expand Down Expand Up @@ -111,7 +122,8 @@ def _to_gpu(sdfg: dace.SDFG):

def _run_sdfg(daceprog: DaceProgram, config: DaceConfig, args, kwargs):
"""Execute a compiled SDFG - do not check for compilation"""
_upload_to_device(list(args) + list(kwargs.values()))
if config.is_gpu_backend():
_upload_to_device(list(args) + list(kwargs.values()))
res = daceprog(*args, **kwargs)
return _download_results_from_dace(config, res, list(args) + list(kwargs.values()))

Expand All @@ -129,15 +141,15 @@ def _build_sdfg(
if config.is_gpu_backend():
_to_gpu(sdfg)
make_transients_persistent(sdfg=sdfg, device=DaceDeviceType.GPU)

# Upload args to device
_upload_to_device(list(args) + list(kwargs.values()))
else:
for sd, _aname, arr in sdfg.arrays_recursive():
if arr.shape == (1,):
arr.storage = DaceStorageType.Register
make_transients_persistent(sdfg=sdfg, device=DaceDeviceType.CPU)

# Upload args to device
_upload_to_device(list(args) + list(kwargs.values()))

# Build non-constants & non-transients from the sdfg_kwargs
sdfg_kwargs = daceprog._create_sdfg_args(sdfg, args, kwargs)
for k in daceprog.constant_args:
Expand Down
84 changes: 44 additions & 40 deletions dsl/pace/dsl/gt4py_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import logging
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import gt4py.backend
import gt4py.storage as gt_storage
import gt4py
import numpy as np

from pace.dsl.typing import DTypes, Field, Float, FloatField
from pace.dsl.typing import DTypes, Field, Float


try:
Expand Down Expand Up @@ -50,6 +49,34 @@ def wrapper(*args, **kwargs) -> Any:
return inner


def _mask_to_dimensions(
mask: Tuple[bool, ...], shape: Sequence[int]
) -> List[Union[str, int]]:
assert len(mask) == 3
dimensions: List[Union[str, int]] = []
for i, axis in enumerate(("I", "J", "K")):
if mask[i]:
dimensions.append(axis)
offset = int(sum(mask))
dimensions.extend(shape[offset:])
return dimensions


def _translate_origin(origin: Sequence[int], mask: Tuple[bool, ...]) -> Sequence[int]:
if len(origin) == int(sum(mask)):
# Correct length. Assumedd to be correctly specified.
return origin

assert len(mask) == 3
final_origin: List[int] = []
for i, has_axis in enumerate(mask):
if has_axis:
final_origin.append(origin[i])

final_origin.extend(origin[len(mask) :])
return final_origin


def make_storage_data(
data: Field,
shape: Optional[Tuple[int, ...]] = None,
Expand Down Expand Up @@ -129,14 +156,12 @@ def make_storage_data(
else:
data = _make_storage_data_3d(data, shape, start, backend=backend)

storage = gt_storage.from_array(
data=data,
storage = gt4py.storage.from_array(
data,
dtype,
backend=backend,
default_origin=origin,
shape=shape,
dtype=dtype,
mask=mask,
managed_memory=managed_memory,
aligned_index=_translate_origin(origin, mask),
dimensions=_mask_to_dimensions(mask, data.shape),
)
return storage

Expand Down Expand Up @@ -264,13 +289,12 @@ def make_storage_from_shape(
mask = (False, False, True) # Assume 1D is a k-field
else:
mask = (n_dims * (True,)) + ((3 - n_dims) * (False,))
storage = gt_storage.zeros(
storage = gt4py.storage.zeros(
shape,
dtype,
backend=backend,
default_origin=origin,
shape=shape,
dtype=dtype,
mask=mask,
managed_memory=managed_memory,
aligned_index=_translate_origin(origin, mask),
dimensions=_mask_to_dimensions(mask, shape),
)
return storage

Expand Down Expand Up @@ -340,8 +364,6 @@ def k_split_run(func, data, k_indices, splitvars_values):


def asarray(array, to_type=np.ndarray, dtype=None, order=None):
if isinstance(array, gt_storage.storage.Storage):
array = array.data
if cp and (isinstance(array, list)):
if to_type is np.ndarray:
order = "F" if order is None else order
Expand Down Expand Up @@ -379,19 +401,15 @@ def is_gpu_backend(backend: str) -> bool:
def zeros(shape, dtype=Float, *, backend: str):
storage_type = cp.ndarray if is_gpu_backend(backend) else np.ndarray
xp = cp if cp and storage_type is cp.ndarray else np
return xp.zeros(shape)
return xp.zeros(shape, dtype=dtype)


def sum(array, axis=None, dtype=Float, out=None, keepdims=False):
if isinstance(array, gt_storage.storage.Storage):
array = array.data
xp = cp if cp and type(array) is cp.ndarray else np
return xp.sum(array, axis, dtype, out, keepdims)


def repeat(array, repeats, axis=None):
if isinstance(array, gt_storage.storage.Storage):
array = array.data
xp = cp if cp and type(array) is cp.ndarray else np
return xp.repeat(array, repeats, axis)

Expand All @@ -401,22 +419,16 @@ def index(array, key):


def moveaxis(array, source: int, destination: int):
jdahm marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(array, gt_storage.storage.Storage):
array = array.data
xp = cp if cp and type(array) is cp.ndarray else np
return xp.moveaxis(array, source, destination)


def tile(array, reps: Union[int, Tuple[int, ...]]):
if isinstance(array, gt_storage.storage.Storage):
array = array.data
xp = cp if cp and type(array) is cp.ndarray else np
return xp.tile(array, reps)


def squeeze(array, axis: Union[int, Tuple[int]] = None):
if isinstance(array, gt_storage.storage.Storage):
array = array.data
xp = cp if cp and type(array) is cp.ndarray else np
return xp.squeeze(array, axis)

Expand Down Expand Up @@ -444,17 +456,13 @@ def unique(
return_counts: bool = False,
axis: Union[int, Tuple[int]] = None,
):
if isinstance(array, gt_storage.storage.Storage):
array = array.data
xp = cp if cp and type(array) is cp.ndarray else np
return xp.unique(array, return_index, return_inverse, return_counts, axis)


def stack(tup, axis: int = 0, out=None):
array_tup = []
for array in tup:
if isinstance(array, gt_storage.storage.Storage):
array = array.data
array_tup.append(array)
xp = cp if cp and type(array_tup[0]) is cp.ndarray else np
return xp.stack(array_tup, axis, out)
Expand All @@ -465,7 +473,7 @@ def device_sync(backend: str) -> None:
cp.cuda.Device(0).synchronize()


def split_cartesian_into_storages(var: FloatField):
def split_cartesian_into_storages(var: np.ndarray) -> Sequence[np.ndarray]:
"""
Provided a storage of dims [X_DIM, Y_DIM, CARTESIAN_DIM]
or [X_INTERFACE_DIM, Y_INTERFACE_DIM, CARTESIAN_DIM]
Expand All @@ -475,10 +483,6 @@ def split_cartesian_into_storages(var: FloatField):
var_data = []
for cart in range(3):
var_data.append(
make_storage_data(
asarray(var.data, type(var.data))[:, :, cart],
var.data.shape[0:2],
backend=var.backend,
)
asarray(var, type(var))[:, :, cart],
)
return var_data
25 changes: 19 additions & 6 deletions dsl/pace/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import gt4py
import numpy as np
from gt4py import gtscript
from gt4py.storage.storage import Storage
from gtc.passes.oir_pipeline import DefaultPipeline, OirPipeline

import pace.util
Expand All @@ -32,13 +31,19 @@
from pace.util.mpi import MPI


try:
import cupy as cp
except ImportError:
cp = np


def report_difference(args, kwargs, args_copy, kwargs_copy, function_name, gt_id):
report_head = f"comparing against numpy for func {function_name}, gt_id {gt_id}:"
report_segments = []
for i, (arg, numpy_arg) in enumerate(zip(args, args_copy)):
if isinstance(arg, pace.util.Quantity):
arg = arg.storage
numpy_arg = numpy_arg.storage
arg = arg.data
numpy_arg = numpy_arg.data
if isinstance(arg, np.ndarray):
report_segments.append(report_diff(arg, numpy_arg, label=f"arg {i}"))
for name in kwargs:
Expand Down Expand Up @@ -429,7 +434,7 @@ def __call__(self, *args, **kwargs) -> None:
f"after calling {self._func_name}"
)

def _mark_cuda_fields_written(self, fields: Mapping[str, Storage]):
def _mark_cuda_fields_written(self, fields: Mapping[str, cp.ndarray]):
if self.stencil_config.is_gpu_backend:
for write_field in self._written_fields:
fields[write_field]._set_device_modified()
Expand Down Expand Up @@ -519,12 +524,20 @@ def closure_resolver(self, constant_args, given_args, parent_closure=None):
def _convert_quantities_to_storage(args, kwargs):
for i, arg in enumerate(args):
try:
args[i] = arg.storage
# Check that 'dims' is an attribute of arg. If so,
# this means it's a pace.util.Quantity, so we need
# to pull off the ndarray.
arg.dims
args[i] = arg.data
except AttributeError:
pass
for name, arg in kwargs.items():
try:
kwargs[name] = arg.storage
# Check that 'dims' is an attribute of arg. If so,
# this means it's a pace.util.Quantity, so we need
# to pull off the ndarray.
arg.dims
kwargs[name] = arg.data
except AttributeError:
pass

Expand Down
2 changes: 1 addition & 1 deletion external/gt4py
Submodule gt4py updated 106 files
7 changes: 2 additions & 5 deletions fv3core/examples/standalone/runfile/acoustics.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,10 @@ def get_state_from_input(
) -> Dict[str, SimpleNamespace]:
"""
Transforms the input data from the dictionary of strings
to arrays into a state we can pass in

Input is a dict of arrays. These are transformed into Storage arrays
useable in GT4Py
to arrays into a state we can pass in.

This will also take care of reshaping the arrays into same sized
fields as required by the acoustics
fields as required by the acoustics.
"""
driver_object = TranslateDynCore([grid], namelist, stencil_config)
driver_object._base.make_storage_data_input_vars(input_data)
Expand Down
2 changes: 1 addition & 1 deletion fv3core/pace/fv3core/initialization/dycore_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def init_zeros(cls, quantity_factory: pace.util.QuantityFactory):
if "dims" in _field.metadata.keys():
initial_storages[_field.name] = quantity_factory.zeros(
_field.metadata["dims"], _field.metadata["units"], dtype=float
).storage
).data
return cls.init_from_storages(
storages=initial_storages, sizer=quantity_factory.sizer
)
Expand Down
Loading