Skip to content

Commit

Permalink
Add in asserts to confirm that multiple single instance programs in s…
Browse files Browse the repository at this point in the history
…equence result in the same numerical values as the one multiple instance execution.
  • Loading branch information
nkoskelo committed Aug 8, 2024
1 parent 30ee1e8 commit 65a9e59
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 44 deletions.
35 changes: 13 additions & 22 deletions examples/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,7 @@ def test_one_time_step_advection():

base_shape = np.prod((15, 5))
x0 = actx.from_numpy(rng.random(base_shape))
x1 = actx.from_numpy(rng.random(base_shape))
x2 = actx.from_numpy(rng.random(base_shape))
x3 = actx.from_numpy(rng.random(base_shape))

speed_shape = (1,)
y0 = actx.from_numpy(rng.random(speed_shape))
y1 = actx.from_numpy(rng.random(speed_shape))
y2 = actx.from_numpy(rng.random(speed_shape))
y3 = actx.from_numpy(rng.random(speed_shape))

ht = 0.0001
hx = 0.005
Expand All @@ -52,24 +44,23 @@ def rhs(fields, wave_speed):
return fields + wave_speed * (-1) * (ht / (2 * hx)) * \
(fields[kp1] - fields[km1])

pack_x = pack_for_parameter_study(actx, ParamStudy1, (4,), x0, x1, x2, x3)
breakpoint()
assert pack_x.shape == (75, 4)

pack_y = pack_for_parameter_study(actx, ParamStudy1, (4,), y0, y1, y2, y3)
breakpoint()
assert pack_y.shape == (1, 4)
wave_speeds = [actx.from_numpy(np.random.random(1)) for _ in range(255)]
print(type(wave_speeds[0]))
packed_speeds = pack_for_parameter_study(actx, ParamStudy1, *wave_speeds)

compiled_rhs = actx.compile(rhs)
breakpoint()

output = compiled_rhs(pack_x, pack_y)
breakpoint()
assert output.shape(75, 4)
output = compiled_rhs(x0, packed_speeds)
output = actx.freeze(output)

expanded_output = actx.to_numpy(output).T

# Now for all the single values.
for idx in range(len(wave_speeds)):
out = compiled_rhs(x0, wave_speeds[idx])
out = actx.freeze(out)
assert np.allclose(expanded_output[idx], actx.to_numpy(out))

output_x = unpack_parameter_study(output, ParamStudy1)
assert len(output_x) == 1 # Only 1 study associated with this variable.
assert len(output_x[0]) == 4 # 4 inputs for the parameter study.

print("All checks passed")

Expand Down
69 changes: 47 additions & 22 deletions examples/parameter_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,11 @@
y3 = actx.from_numpy(rng.random(base_shape))


# Eq: z = x + y
# Eq: z = x @ y.T
# Assumptions: x and y are undergoing independent parameter studies.
# x and y are matrices such that x @ y.T works in the single instance case.
def rhs(param1, param2):
import pytato as pt
return pt.matmul(param1, param2.T)
return pt.stack([param1[0], param2[10]], axis=0)
return param1[0] + param2[10]

return param1 @ param2.T

@dataclass(frozen=True)
class ParameterStudyForX(ParameterStudyAxisTag):
Expand All @@ -50,26 +47,54 @@ class ParameterStudyForY(ParameterStudyAxisTag):

# Pack a parameter study of 3 instances for x and and 4 instances for y.


packx = pack_for_parameter_study(actx, ParameterStudyForX, (3,), x, x1, x2)
packy = pack_for_parameter_study(actx, ParameterStudyForY, (4,), y, y1, y2, y3)
packx = pack_for_parameter_study(actx, ParameterStudyForX, x, x1, x2)
packy = pack_for_parameter_study(actx, ParameterStudyForY, y, y1, y2, y3)

compiled_rhs = actx.compile(rhs) # Build the function caller

# Builds a trace for a single instance of evaluating the RHS and
# then converts it to a program which takes our multiple instances of `x` and `y`.
output = compiled_rhs(packx, packy)
output_2 = compiled_rhs(x, y)
breakpoint()

assert output.shape == (15, 5, 3, 4) # Distinct parameter studies.

output_x = unpack_parameter_study(output, ParameterStudyForX)
output_y = unpack_parameter_study(output, ParameterStudyForY)
assert len(output_x) == 1 # Number of parameter studies involving "x"
assert len(output_x[0]) == 3 # Number of inputs in the 0th parameter study
# All outputs across every other parameter study.
assert output_x[0][0].shape == (15, 5, 4)
assert len(output_y) == 1
assert len(output_y[0]) == 4
assert output_y[0][0].shape == (15, 5, 3)

numpy_output = actx.to_numpy(output)

assert numpy_output.shape == (15, 15, 3, 4)

out = actx.to_numpy(compiled_rhs(x, y))
assert np.allclose(numpy_output[..., 0, 0], out)

out = actx.to_numpy(compiled_rhs(x, y1))
assert np.allclose(numpy_output[..., 0, 1], out)

out = actx.to_numpy(compiled_rhs(x, y2))
assert np.allclose(numpy_output[..., 0, 2], out)

out = actx.to_numpy(compiled_rhs(x, y3))
assert np.allclose(numpy_output[..., 0, 3], out)

out = actx.to_numpy(compiled_rhs(x1, y))
assert np.allclose(numpy_output[..., 1, 0], out)

out = actx.to_numpy(compiled_rhs(x1, y1))
assert np.allclose(numpy_output[..., 1, 1], out)

out = actx.to_numpy(compiled_rhs(x1, y2))
assert np.allclose(numpy_output[..., 1, 2], out)

out = actx.to_numpy(compiled_rhs(x1, y3))
assert np.allclose(numpy_output[..., 1, 3], out)

out = actx.to_numpy(compiled_rhs(x2, y))
assert np.allclose(numpy_output[..., 2, 0], out)

out = actx.to_numpy(compiled_rhs(x2, y1))
assert np.allclose(numpy_output[..., 2, 1], out)

out = actx.to_numpy(compiled_rhs(x2, y2))
assert np.allclose(numpy_output[..., 2, 2], out)

out = actx.to_numpy(compiled_rhs(x2, y3))
assert np.allclose(numpy_output[..., 2, 3], out)

print("All tests passed!")

0 comments on commit 65a9e59

Please sign in to comment.