
Commit

Fix formatting.
nkoskelo committed Jul 22, 2024
1 parent 81b39bc commit d0e806b
Showing 2 changed files with 56 additions and 55 deletions.
3 changes: 1 addition & 2 deletions arraycontext/parameter_study/__init__.py
@@ -58,7 +58,6 @@

import numpy as np

import pytato as pt
from pytato.array import Array

from arraycontext.context import ArrayContext
@@ -98,7 +97,7 @@ def pack_for_parameter_study(actx: ArrayContext,

orig_shape = args[0].shape
out = actx.np.stack(args, axis=len(args[0].shape))
outshape = tuple([*list(orig_shape), newshape])
outshape = *orig_shape, newshape

# if len(newshape) > 1:
# # Reshape the object
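
For reference, the two spellings of `outshape` build the same tuple; the new bare starred tuple display just drops the `list()`/`tuple()` round trip. A minimal sketch with hypothetical stand-in values for `orig_shape` and `newshape`:

```python
# Hypothetical stand-ins: `orig_shape` is the shape of one study instance,
# `newshape` the extent of the appended study axis.
orig_shape = (4, 4)
newshape = 10

outshape_old = tuple([*list(orig_shape), newshape])   # before this commit
outshape_new = *orig_shape, newshape                  # after this commit

assert outshape_old == outshape_new == (4, 4, 10)
```
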
108 changes: 55 additions & 53 deletions arraycontext/parameter_study/transform.py
@@ -85,6 +85,9 @@
)


ArraysT = Tuple[Array, ...]


@dataclass(frozen=True)
class ParameterStudyAxisTag(UniqueTag):
"""
@@ -100,6 +103,9 @@ class ParameterStudyAxisTag(UniqueTag):
axis_size: int


StudiesT = Tuple[ParameterStudyAxisTag, ...]
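
The two aliases shorten the tuple annotations used throughout the mapper. A small sketch of how they read in a signature, with stand-in classes replacing the real pytato/arraycontext types:

```python
from typing import Tuple

class Array: ...                    # stand-in for pytato.array.Array
class ParameterStudyAxisTag: ...    # stand-in for the tag class above

ArraysT = Tuple[Array, ...]
StudiesT = Tuple[ParameterStudyAxisTag, ...]

def describe(arrays: ArraysT, studies: StudiesT) -> Tuple[int, int]:
    return len(arrays), len(studies)

assert describe((Array(), Array()), (ParameterStudyAxisTag(),)) == (2, 1)
```
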


class ExpansionMapper(CopyMapper):

def __init__(self, actual_input_shapes: Mapping[str, ShapeType],
@@ -108,8 +114,7 @@ def __init__(self, actual_input_shapes: Mapping[str, ShapeType],
self.actual_input_shapes = actual_input_shapes
self.actual_input_axes = actual_input_axes

def single_predecessor_updates(self,
curr_expr: Array,
def single_predecessor_updates(self, curr_expr: Array,
new_expr: Array) -> Tuple[ShapeType,
AxesT]:
# Initialize with something for the typing.
@@ -125,8 +130,8 @@ def single_predecessor_updates(self,
for _j, tag in enumerate(axis_tags):
# Should be relatively few tags on each axis $O(1)$.
if isinstance(tag, ParameterStudyAxisTag):
new_axes = new_axes + (new_expr.axes[i],)
shape_to_append = shape_to_append + (new_expr.shape[i],)
new_axes = *new_axes, new_expr.axes[i],
shape_to_append = *shape_to_append, new_expr.shape[i],
if already_added:
raise ValueError("An individual axis may only be " +
"tagged with one ParameterStudyAxisTag.")
@@ -137,8 +142,7 @@

def map_roll(self, expr: Roll) -> Array:
new_array = self.rec(expr.array)
_, new_axes = self.single_predecessor_updates(expr.array,
new_array)
_, new_axes = self.single_predecessor_updates(expr.array, new_array)
return Roll(array=new_array,
shift=expr.shift,
axis=expr.axis,
@@ -148,8 +152,7 @@ def map_roll(self, expr: Roll) -> Array:

def map_axis_permutation(self, expr: AxisPermutation) -> Array:
new_array = self.rec(expr.array)
postpend_shape, new_axes = self.single_predecessor_updates(expr.array,
new_array)
postpend_shape, new_axes = self.single_predecessor_updates(expr.array, new_array)
# Include the axes we are adding to the system.
axis_permute = expr.axis_permutation + tuple([i + len(expr.axis_permutation)
for i in range(len(postpend_shape))])
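
The permutation handling extends the user's permutation with identity entries so the appended study axes keep their trailing positions. A numpy sketch, assuming two instance axes and two study axes (all sizes hypothetical):

```python
import numpy as np

# Assumed: two instance axes being swapped, two appended study axes.
axis_permutation = (1, 0)
n_study_axes = 2

axis_permute = axis_permutation + tuple(
    i + len(axis_permutation) for i in range(n_study_axes))
assert axis_permute == (1, 0, 2, 3)

a = np.zeros((4, 5, 10, 3))   # instance, instance, study_a, study_b
assert np.transpose(a, axis_permute).shape == (5, 4, 10, 3)
```
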
@@ -161,9 +164,9 @@ def map_axis_permutation(self, expr: AxisPermutation) -> Array:
non_equality_tags=expr.non_equality_tags)

def _map_index_base(self, expr: IndexBase) -> Array:
new_array = self.rec(expr.array)
_, new_axes = self.single_predecessor_updates(expr.array,
new_array)
_, new_axes = self.single_predecessor_updates(expr.array, new_array)
return type(expr)(new_array,
indices=self.rec_idx_or_size_tuple(expr.indices),
# May need to modify indices
@@ -173,10 +176,9 @@ def _map_index_base(self, expr: IndexBase) -> Array:

def map_reshape(self, expr: Reshape) -> Array:
new_array = self.rec(expr.array)
postpend_shape, new_axes = self.single_predecessor_updates(expr.array,
new_array)
postpend_shape, new_axes = self.single_predecessor_updates(expr.array, new_array)
return Reshape(new_array,
newshape=self.rec_idx_or_size_tuple(expr.newshape +
newshape=self.rec_idx_or_size_tuple(expr.newshape + \
postpend_shape),
order=expr.order,
axes=expr.axes + new_axes,
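
The reshape path applies the requested `newshape` to the instance axes only and re-appends the study extents. A numpy sketch with assumed sizes:

```python
import numpy as np

# Assumed sizes: one instance is (4, 4), reshaped to (16,); study axes of
# sizes 10 and 3 were appended by packing and must survive unchanged.
packed = np.zeros((4, 4, 10, 3))
requested_newshape = (16,)
postpend_shape = (10, 3)

reshaped = packed.reshape(requested_newshape + postpend_shape)
assert reshaped.shape == (16, 10, 3)
```
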
@@ -202,13 +204,14 @@ def map_placeholder(self, expr: Placeholder) -> Array:
# {{{ Operations with multiple predecessors.

def _studies_from_multiple_pred(self,
new_arrays: Tuple[Array, ...]) -> Tuple[AxesT,
Set[ParameterStudyAxisTag],
Dict[Array, Tuple[ParameterStudyAxisTag, ...]]]:
new_arrays: ArraysT) -> Tuple[AxesT,
Set[ParameterStudyAxisTag],
Dict[Array,
StudiesT]]:

new_axes_for_end: Tuple[Axis, ...] = ()
new_axes_for_end: AxesT = ()
cur_studies: Set[ParameterStudyAxisTag] = set()
studies_by_array: Dict[Array, Tuple[ParameterStudyAxisTag, ...]] = {}
studies_by_array: Dict[Array, StudiesT] = {}

for _ind, array in enumerate(new_arrays):
for axis in array.axes:
@@ -224,7 +227,7 @@ def _studies_from_multiple_pred(self,

if axis_tags[0] not in cur_studies:
cur_studies.add(axis_tags[0])
new_axes_for_end = new_axes_for_end + (axis,)
new_axes_for_end = *new_axes_for_end, axis

return new_axes_for_end, cur_studies, studies_by_array
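
`_studies_from_multiple_pred` records, per predecessor, which studies tag its axes, while each distinct study is kept once, in first-seen order, for the appended axes. A sketch with strings standing in for tag instances:

```python
# Hypothetical predecessors and the study tags found on their axes.
tags_by_array = {"u": ("study_a",),
                 "v": ("study_a", "study_b")}

cur_studies = set()
studies_by_array = {}
new_axes_for_end = ()
for name, tags in tags_by_array.items():
    studies_by_array[name] = tags
    for tag in tags:
        if tag not in cur_studies:
            cur_studies.add(tag)
            new_axes_for_end = *new_axes_for_end, f"Axis({tag})"

assert new_axes_for_end == ("Axis(study_a)", "Axis(study_b)")
```
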

@@ -246,27 +249,27 @@ def map_concatenate(self, expr: Concatenate) -> Array:
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)

def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[Tuple[Array, ...],
AxesT]:
def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[ArraysT,
AxesT]:

sing_inst_in_shape = expr.arrays[0].shape
one_inst_in_shape = expr.arrays[0].shape
new_arrays = tuple(self.rec(arr) for arr in expr.arrays)

_, cur_studies, studies_by_array = self._studies_from_multiple_pred(new_arrays)

study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {}

new_shape_of_predecessors = sing_inst_in_shape
new_shape_of_predecessors = one_inst_in_shape
new_axes = expr.axes

for study in cur_studies:
if isinstance(study, ParameterStudyAxisTag):
# Just defensive programming
# The active studies are added to the end of the bindings.
study_to_axis_number[study] = len(new_shape_of_predecessors)
new_shape_of_predecessors = new_shape_of_predecessors + \
(study.axis_size,)
new_shape_of_predecessors = *new_shape_of_predecessors, \
study.axis_size
new_axes = new_axes + (Axis(tags=frozenset((study,))),)
new_axes = *new_axes, Axis(tags=frozenset((study,))),
# This assumes that the axis only has 1 tag,
# because there should be no dependence across instances.
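
Each active study claims the next free trailing axis number, and the predecessors' common shape grows by that study's extent. A sketch with hypothetical sizes, which also shows that the starred form appends the bare integer rather than a nested tuple:

```python
# Hypothetical study sizes; each active study claims one trailing axis.
study_sizes = {"study_a": 10, "study_b": 3}
one_inst_in_shape = (4, 4)

new_shape = one_inst_in_shape
study_to_axis_number = {}
for study, size in study_sizes.items():
    study_to_axis_number[study] = len(new_shape)
    new_shape = *new_shape, size        # appends the int, not a tuple

assert new_shape == (4, 4, 10, 3)
assert study_to_axis_number == {"study_a": 2, "study_b": 3}
```
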

@@ -275,41 +278,40 @@ def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[Tuple[
# Now we need to update the expressions.
# Now that we have the appropriate shape,
# we need to update the input arrays to match.

cp_map = CopyMapper()
corrected_new_arrays: Tuple[Array, ...] = ()
corrected_new_arrays: ArraysT = ()
for _, array in enumerate(new_arrays):
tmp = cp_map(array) # Get a copy of the array.
if len(array.axes) < len(new_axes):
# We need to grow the array to the new size.
for study in cur_studies:
if study not in studies_by_array[array]:
build: Tuple[Array, ...] = tuple([cp_map(tmp) for
build: ArraysT = tuple([cp_map(tmp) for
_ in range(study.axis_size)])
tmp = Stack(arrays=build, axis=len(tmp.axes),
axes=tmp.axes +
(Axis(tags=frozenset((study,))),),
axes=(*tmp.axes, Axis(tags=frozenset((study,)))),
tags=tmp.tags,
non_equality_tags=tmp.non_equality_tags)
elif len(array.axes) > len(new_axes):
raise ValueError("Input array is too big. " + \
f"Expected at most: {len(new_axes)} " + \
raise ValueError("Input array is too big. " +
f"Expected at most: {len(new_axes)} " +
f"Found: {len(array.axes)} axes.")

# Now we need to correct to the appropriate shape with an axis permutation.
# These are known to be in the right place.
permute: Tuple[int, ...] = tuple([i for i in range(len(sing_inst_in_shape))])
permute: Tuple[int, ...] = tuple([i for i in range(len(one_inst_in_shape))])

for _, axis in enumerate(tmp.axes):
axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag))
if axis_tags:
assert len(axis_tags) == 1
permute = permute + (study_to_axis_number[axis_tags[0]],)
permute = *permute, study_to_axis_number[axis_tags[0]],
assert len(permute) == len(new_shape_of_predecessors)
corrected_new_arrays = corrected_new_arrays + \
(AxisPermutation(tmp, permute, tags=tmp.tags,
corrected_new_arrays = *corrected_new_arrays, \
AxisPermutation(tmp, permute, tags=tmp.tags,
axes=tmp.axes,
non_equality_tags=tmp.non_equality_tags),)
non_equality_tags=tmp.non_equality_tags),

return corrected_new_arrays, new_axes
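
The loop grows each predecessor to the full study shape by stacking copies along every study it does not participate in, then permutes so the study axes agree before the outer Stack/Concatenate. A numpy analogue with assumed sizes:

```python
import numpy as np

# `a` varies in study_a (size 10), `b` in study_b (size 3); both are grown
# to the common shape (4, 4, 10, 3) before the outer stack.
a = np.zeros((4, 4, 10))    # instance, instance, study_a
b = np.zeros((4, 4, 3))     # instance, instance, study_b

a_grown = np.stack([a] * 3, axis=a.ndim)        # (4, 4, 10, 3)
b_grown = np.stack([b] * 10, axis=b.ndim)       # (4, 4, 3, 10)
b_grown = np.transpose(b_grown, (0, 1, 3, 2))   # study axes now line up

stacked = np.stack([a_grown, b_grown], axis=0)
assert stacked.shape == (2, 4, 4, 10, 3)
```
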

@@ -322,18 +324,18 @@ def map_index_lambda(self, expr: IndexLambda) -> Array:
# Determine the new parameter studies that are being conducted.
from pytools import unique

all_axis_tags: Tuple[ParameterStudyAxisTag, ...] = ()
studies_by_variable: Dict[str, Dict[UniqueTag, bool]] = {}
all_axis_tags: StudiesT = ()
varname_to_studies: Dict[str, Dict[UniqueTag, bool]] = {}
for name, bnd in sorted(new_bindings.items()):
axis_tags_for_bnd: Set[Tag] = set()
studies_by_variable[name] = {}
varname_to_studies[name] = {}
for i in range(len(bnd.axes)):
axis_tags_for_bnd = axis_tags_for_bnd.union(bnd.axes[i].tags_of_type(ParameterStudyAxisTag))
for tag in axis_tags_for_bnd:
if isinstance(tag, ParameterStudyAxisTag):
# Defense
studies_by_variable[name][tag] = True
all_axis_tags = all_axis_tags + (tag,)
varname_to_studies[name][tag] = True
all_axis_tags = *all_axis_tags, tag,

cur_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags))
study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {}
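
`list(unique(...))` from pytools is relied on here for order-preserving deduplication of the collected tags; plain `dict.fromkeys` gives the same behavior:

```python
# Hypothetical tag sequence with a repeat.
all_axis_tags = ("study_a", "study_b", "study_a")

cur_studies = list(dict.fromkeys(all_axis_tags))
assert cur_studies == ["study_a", "study_b"]
```
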
@@ -346,13 +348,13 @@ def map_index_lambda(self, expr: IndexLambda) -> Array:
# Just defensive programming
# The active studies are added to the end of the bindings.
study_to_axis_number[study] = len(new_shape)
new_shape = new_shape + (study.axis_size,)
new_axes = new_axes + (Axis(tags=frozenset((study,))),)
new_shape = *new_shape, study.axis_size,
new_axes = *new_axes, Axis(tags=frozenset((study,))),
# This assumes that the axis only has 1 tag,
# because there should be no dependence across instances.

# Now we need to update the expressions.
scalar_expr = ParamAxisExpander()(expr.expr, studies_by_variable,
scalar_expr = ParamAxisExpander()(expr.expr, varname_to_studies,
study_to_axis_number)

return IndexLambda(expr=scalar_expr,
@@ -380,7 +382,7 @@ def map_einsum(self, expr: Einsum) -> Array:
# Just defensive programming
# The active studies are added to the end.
study_to_axis_number[study] = len(new_shape)
new_shape = new_shape + (study.axis_size,)
new_shape = *new_shape, study.axis_size,

for ind, array in enumerate(new_arrays):
for _, axis in enumerate(array.axes):
@@ -403,25 +405,25 @@ def map_einsum(self, expr: Einsum) -> Array:
class ParamAxisExpander(IdentityMapper):

def map_subscript(self, expr: prim.Subscript,
studies_by_variable: Mapping[str,
Mapping[ParameterStudyAxisTag, bool]],
varname_to_studies: Mapping[str,
Mapping[ParameterStudyAxisTag, bool]],
study_to_axis_number: Mapping[ParameterStudyAxisTag, int]):
# We know that we are not changing the variable that we are indexing into.
# This is stored in the aggregate member of the class Subscript.

# We only need to modify the indexing which is stored in the index member.
name = expr.aggregate.name
if name in studies_by_variable.keys():
if name in varname_to_studies.keys():
# These are the single instance information.
index = self.rec(expr.index, studies_by_variable,
index = self.rec(expr.index, varname_to_studies,
study_to_axis_number)

new_vars: Tuple[prim.Variable, ...] = ()

for key, num in sorted(study_to_axis_number.items(),
key=lambda item: item[1]):
if key in studies_by_variable[name]:
new_vars = new_vars + (prim.Variable(f"_{num}"),)
if key in varname_to_studies[name]:
new_vars = *new_vars, prim.Variable(f"_{num}"),

if isinstance(index, tuple):
index = index + new_vars
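
`map_subscript` appends one index variable (`_<axis number>`) per study the aggregate participates in. A sketch using pymbolic primitives directly; the axis numbers 2 and 3 are hypothetical:

```python
import pymbolic.primitives as prim

# u[i, j] stands in for one single-instance access.
sub = prim.Subscript(prim.Variable("u"),
                     (prim.Variable("i"), prim.Variable("j")))

# Append the study index variables, mirroring how the expander
# extends the index tuple.
new_vars = (prim.Variable("_2"), prim.Variable("_3"))
new_sub = prim.Subscript(sub.aggregate, sub.index + new_vars)
print(new_sub)   # u[i, j, _2, _3]
```
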
@@ -559,7 +561,7 @@ def _cut_if_in_param_study(name, arg) -> Array:
for i in range(ndim):
axis_tags = arg.axes[i].tags_of_type(ParameterStudyAxisTag)
if not axis_tags:
update_axes = update_axes + (arg.axes[i],)
update_axes = *update_axes, arg.axes[i],
newshape.append(arg.shape[i])
# remove the first one that was placed there for typing.
update_axes = update_axes[1:]
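
`_cut_if_in_param_study` keeps only the untagged (instance) axes, seeding the tuple with a throwaway entry for the type checker and dropping it afterwards. A stand-in sketch of that pattern:

```python
# Hypothetical tagging of three axes; axis 1 belongs to a study.
axis_is_study = (False, True, False)
shape = (4, 10, 7)

update_axes = ("sentinel",)      # placeholder entry, for typing only
newshape = []
for i, tagged in enumerate(axis_is_study):
    if not tagged:
        update_axes = *update_axes, f"axis{i}"
        newshape.append(shape[i])

update_axes = update_axes[1:]    # drop the placeholder again
assert update_axes == ("axis0", "axis2")
assert newshape == [4, 7]
```
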
