Skip to content

Commit

Permalink
Correct most of the type annotations.
Browse files Browse the repository at this point in the history
  • Loading branch information
nkoskelo committed Jul 16, 2024
1 parent 804ed42 commit 607db79
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 57 deletions.
35 changes: 20 additions & 15 deletions arraycontext/parameter_study/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,15 @@
Union,
Sequence,
List,
Iterable,
Mapping,
)

import numpy as np
import pytato as pt

from pytato.array import Array

from pytools import memoize_method
from pytools.tag import Tag, ToTagSetConvertible, normalize_tags, UniqueTag

Expand All @@ -74,7 +78,7 @@

from arraycontext.container import ArrayContainer, is_array_container_type

from arraycontext.context import ArrayT, ArrayContext
from arraycontext.context import ArrayContext
from arraycontext.metadata import NameHint
from arraycontext import PytatoPyOpenCLArrayContext
from arraycontext.impl.pytato.compile import (LazilyPyOpenCLCompilingFunctionCaller,
Expand All @@ -99,9 +103,10 @@
logger = logging.getLogger(__name__)


def pack_for_parameter_study(actx: ArrayContext, study_name_tag: ParameterStudyAxisTag,
def pack_for_parameter_study(actx: ArrayContext,
study_name_tag_type: Type[ParameterStudyAxisTag],
newshape: Tuple[int, ...],
*args: ArrayT) -> ArrayT:
*args: Array) -> Array:
"""
Args is a list of variable names and the realized input data that needs
to be packed for a parameter study or uncertainty quantification.
Expand All @@ -115,42 +120,42 @@ def pack_for_parameter_study(actx: ArrayContext, study_name_tag: ParameterStudyA
assert len(args) == np.prod(newshape)

orig_shape = args[0].shape
out = actx.np.stack(args, axis=args[0].ndim)
out = actx.np.stack(args, axis=len(args[0].shape))
outshape = tuple(list(orig_shape) + [newshape] )

#if len(newshape) > 1:
# # Reshape the object
# out = out.reshape(outshape)

for i in range(len(orig_shape), len(outshape)):
out = out.with_tagged_axis(i, [study_name_tag(i - len(orig_shape), newshape[i-len(orig_shape)])])
out = out.with_tagged_axis(i, [study_name_tag_type(i - len(orig_shape), newshape[i-len(orig_shape)])])
return out


def unpack_parameter_study(data: ArrayT,
study_name_tag: ParameterStudyAxisTag) -> Dict[int,
List[ArrayT]]:
def unpack_parameter_study(data: Array,
study_name_tag_type: Type[ParameterStudyAxisTag]) -> Mapping[int,
List[Array]]:
"""
Split the data array along the axes which vary according to a ParameterStudyAxisTag
whose name tag is an instance study_name_tag.
whose name tag is an instance study_name_tag_type.
output[i] corresponds to the values associated with the ith parameter study that
uses the variable name :arg: `study_name_tag`.
uses the variable name :arg: `study_name_tag_type`.
"""

ndim: int = len(data.axes)
out: Dict[int, List[ArrayT]] = {}
ndim: int = len(data.shape)
out: Dict[int, List[Array]] = {}

study_count = 0
for i in range(ndim):
axis_tags = data.axes[i].tags_of_type(study_name_tag)
axis_tags = data.axes[i].tags_of_type(study_name_tag_type)
if axis_tags:
# Now we need to split this data.
breakpoint()
for j in range(data.shape[i]):
tmp: List[slice] = [slice(None)] * ndim
tmp: List[Any] = [slice(None)] * ndim
tmp[i] = j
the_slice: Tuple[slice] = tuple(tmp)
the_slice = tuple(tmp)
# Needs to be a tuple of slices not list of slices.
if study_count in out.keys():
out[study_count].append(data[the_slice])
Expand Down
97 changes: 55 additions & 42 deletions arraycontext/parameter_study/transform.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
"""
.. currentmodule:: arraycontext
Expand Down Expand Up @@ -46,6 +47,7 @@
Sequence,
List,
Mapping,
Set,
)

import numpy as np
Expand Down Expand Up @@ -84,7 +86,7 @@
DataWrapper, SizeParam, DictOfNamedArrays, AbstractResultWithNamedArrays,
Reshape, Concatenate, NamedArray, IndexRemappingBase, Einsum,
InputArgumentBase, AdvancedIndexInNoncontiguousAxes, IndexBase, DataInterface,
Axis)
Axis, ShapeType)

from pytato.utils import broadcast_binary_op

Expand Down Expand Up @@ -115,23 +117,30 @@ def __init__(self, actual_input_shapes: Mapping[str, Tuple[int,...]],


def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: Array,
new_expr: Array) -> Tuple[Optional[Tuple[int]],
Optional[Tuple[Axis]]]:
shape_to_prepend: Tuple[int] = tuple([])
new_axes: Tuple[Axis] = tuple([])
new_expr: Array) -> Tuple[ShapeType,
Tuple[Axis,...]]:
# Initialize with something for the typing.
shape_to_append: ShapeType= (-1,)
new_axes: Tuple[Axis,...] = (Axis(tags=frozenset()),)
if curr_expr.shape == new_expr.shape:
return shape_to_prepend, new_axes
return shape_to_append, new_axes

# Now we may need to change.
changed = False
for i in range(len(new_expr.axes)):
axis_tags = list(new_expr.axes[i].tags)
already_added = False
for j, tag in enumerate(axis_tags):
# Should be relatively few tags on each axis $O(1)$.
if isinstance(tag, ParameterStudyAxisTag):
new_axes = new_axes + (new_expr.axes[i],)
shape_to_prepend = shape_to_prepend + (new_expr.shape[i],)
return shape_to_prepend, new_axes
shape_to_append = shape_to_append + (new_expr.shape[i],)
if already_added:
raise ValueError("An individual axis may only be tagged with one ParameterStudyAxisTag.")
already_added = True

# Remove initialized extraneous data
return shape_to_append[1:], new_axes[1:]


def map_stack(self, expr: Stack) -> Array:
Expand All @@ -143,50 +152,49 @@ def map_concatenate(self, expr: Concatenate) -> Array:

def map_roll(self, expr: Roll) -> Array:
new_array = self.rec(expr.array)
prepend_shape, new_axes =self.does_single_predecessor_require_rewrite_of_this_operation(expr.array,
_, new_axes =self.does_single_predecessor_require_rewrite_of_this_operation(expr.array,
new_array)
return Roll(array=new_array,
shift=expr.shift,
axis=expr.axis + len(new_axes),
axes=new_axes + expr.axes,
axis=expr.axis,
axes=expr.axes + new_axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)

def map_axis_permutation(self, expr: AxisPermutation) -> Array:
new_array = self.rec(expr.array)
prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array,
postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array,
new_array)
axis_permute = tuple([expr.axis_permutation[i] + len(prepend_shape) for i
in range(len(expr.axis_permutation))])
# Include the axes we are adding to the system.
axis_permute = tuple([i for i in range(len(prepend_shape))]) + axis_permute
axis_permute = expr.axis_permutation + tuple([i + len(expr.axis_permutation)
for i in range(len(postpend_shape))])


return AxisPermutation(array=new_array,
axis_permutation=axis_permute,
axes=new_axes + expr.axes,
axes=expr.axes + new_axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)

def _map_index_base(self, expr: IndexBase) -> Array:
new_array = self.rec(expr.array)
prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array,
postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array,
new_array)
return type(expr)(new_array,
indices=self.rec_idx_or_size_tuple(expr.indices),
# May need to modify indices
axes=new_axes + expr.axes,
axes=expr.axes + new_axes,
tags=expr.tags,
non_equality_tags = expr.non_equality_tags)

def map_reshape(self, expr: Reshape) -> Array:
new_array = self.rec(expr.array)
prepend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array,
postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array,
new_array)
return Reshape(new_array,
newshape = self.rec_idx_or_size_tuple(prepend_shape + expr.newshape),
newshape = self.rec_idx_or_size_tuple(expr.newshape + postpend_shape),
order=expr.order,
axes=new_axes + expr.axes,
axes=expr.axes + new_axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)

Expand All @@ -198,7 +206,7 @@ def map_placeholder(self, expr: Placeholder) -> Array:
# We may need to update the size.
if expr.shape != self.actual_input_shapes[expr.name]:
correct_shape = self.actual_input_shapes[expr.name]
correct_axes = self.actual_input_axes[expr.name]
correct_axes = tuple(self.actual_input_axes[expr.name])
return Placeholder(name=expr.name,
shape=self.rec_idx_or_size_tuple(correct_shape),
dtype=expr.dtype,
Expand All @@ -208,30 +216,30 @@ def map_placeholder(self, expr: Placeholder) -> Array:

def map_index_lambda(self, expr: IndexLambda) -> Array:
# Update bindings first.
new_bindings: Mapping[str, Array] = { name: self.rec(bnd)
new_bindings: Dict[str, Array] = { name: self.rec(bnd)
for name, bnd in sorted(expr.bindings.items())}

# Determine the new parameter studies that are being conducted.
from pytools import unique
from pytools.obj_array import flat_obj_array

all_axis_tags: Set[Tag] = set()
studies_by_variable: Mapping[str, Mapping[Tag, bool]] = {}
all_axis_tags: Tuple[ParameterStudyAxisTag,...] = ()
studies_by_variable: Dict[str, Dict[UniqueTag, bool]] = {}
for name, bnd in sorted(new_bindings.items()):
axis_tags_for_bnd: Set[Tag] = set()
studies_by_variable[name] = {}
for i in range(len(bnd.axes)):
axis_tags_for_bnd = axis_tags_for_bnd.union(bnd.axes[i].tags_of_type(ParameterStudyAxisTag))
for tag in axis_tags_for_bnd:
studies_by_variable[name][tag] = 1
all_axis_tags = all_axis_tags.union(axis_tags_for_bnd)
if isinstance(tag, ParameterStudyAxisTag):
# Defense
studies_by_variable[name][tag] = True
all_axis_tags = all_axis_tags + (tag,)

# Freeze the set now.
all_axis_tags = frozenset(all_axis_tags)

active_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags))
axes: Optional[Tuple[Axis]] = tuple([])
study_to_axis_number: Mapping[ParameterStudyAxisTag, int] = {}
axes: Tuple[Axis, ...] = ()
study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {}

count = 0
new_shape = expr.shape
Expand All @@ -251,7 +259,7 @@ def map_index_lambda(self, expr: IndexLambda) -> Array:
scalar_expr = ParamAxisExpander()(expr.expr, studies_by_variable, study_to_axis_number)

return IndexLambda(expr=scalar_expr,
bindings=type(expr.bindings)(new_bindings),
bindings=immutabledict(new_bindings),
shape=new_shape,
var_to_reduction_descr=expr.var_to_reduction_descr,
dtype=expr.dtype,
Expand All @@ -274,7 +282,7 @@ def map_subscript(self, expr: prim.Subscript, studies_by_variable: Mapping[str,
index = self.rec(expr.index, studies_by_variable,
study_to_axis_number)

new_vars: Tuple[prim.Variable] = tuple([])
new_vars: Tuple[prim.Variable, ...] = ()

for key, val in sorted(study_to_axis_number.items(), key=lambda item: item[1]):
if key in studies_by_variable[name]:
Expand Down Expand Up @@ -379,19 +387,25 @@ def _as_dict_of_named_arrays(keys, ary):
rec_keyed_map_array_container(_as_dict_of_named_arrays,
output_template)

input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys()}
input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()}
#input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys() if hasattr(arg_id_to_descr[i], "shape")}
#input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()}
input_shapes = {}
input_axes = {}
for i in arg_id_to_descr.keys():
if hasattr(arg_id_to_descr[i], "shape"):
input_shapes[input_id_to_name_in_program[i]] = arg_id_to_descr[i].shape
input_axes[input_id_to_name_in_program[i]] = arg_id_to_arg[i].axes
myMapper = ExpansionMapper(input_shapes, input_axes) # Get the dependencies
breakpoint()

dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays)
pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays)

breakpoint()
dict_of_named_arrays = myMapper(dict_of_named_arrays) # Update the arrays.

# Use the normal compiler now.

compiled_func = self._dag_to_compiled_func(dict_of_named_arrays,
compiled_func = self._dag_to_compiled_func(myMapper(pt_dict_of_named_arrays), # Update the arrays
#pt_dict_of_named_arrays,
input_id_to_name_in_program=input_id_to_name_in_program,
output_id_to_name_in_program=output_id_to_name_in_program,
output_template=output_template)
Expand All @@ -411,14 +425,13 @@ def _cut_if_in_param_study(name, arg) -> Array:
"""
ndim: int = len(arg.shape)
newshape = []
update_axes = []
update_axes: Tuple[Axis, ...] = (Axis(tags=frozenset()),)
for i in range(ndim):
axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag)
if not axis_tags:
update_axes.append(arg.axes[i])
update_axes = update_axes + (arg.axes[i],)
newshape.append(arg.shape[i])

update_axes = tuple(update_axes)
update_axes = update_axes[1:] # remove the first one that was placed there for typing.
update_tags: FrozenSet[Tag] = arg.tags
return pt.make_placeholder(name, newshape, arg.dtype, axes=update_axes,
tags=update_tags)
Expand Down

0 comments on commit 607db79

Please sign in to comment.