Skip to content

Commit

Permalink
modify how un/reshaping for tp space works for 1d
Browse files Browse the repository at this point in the history
  • Loading branch information
a-alveyblanc committed Dec 13, 2024
1 parent 8064696 commit 84bb977
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
11 changes: 9 additions & 2 deletions modepy/test/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def test_tensor_product_reshape(dim):

# {{{ test_tensor_product_vdm_dim_by_dim

@pytest.mark.parametrize("dim", [2, 3])
@pytest.mark.parametrize("dim", [1, 2, 3])
def test_tensor_product_vdm_dim_by_dim(dim):
"""Apply tensor product Vandermonde one dimension at a time, check that the
result matches what's obtained via the whole-space Vandermonde.
Expand All @@ -580,8 +580,15 @@ def test_tensor_product_vdm_dim_by_dim(dim):
x_r = reshape_array_for_tensor_product_space(space, x)
vdm_dimbydim_x_r = x_r

if dim == 1:
space_bases = (space,)
shape_bases = (shape,)
else:
space_bases = space.bases
shape_bases = shape.bases

for i, (subspace, subshape) in enumerate(
zip(space.bases, shape.bases, strict=True)):
zip(space_bases, shape_bases, strict=True)):
subnodes = mp.edge_clustered_nodes_for_space(subspace, subshape)
subbasis = mp.basis_for_space(subspace, subshape)
subvdm = mp.vandermonde(subbasis.functions, subnodes)
Expand Down
4 changes: 4 additions & 0 deletions modepy/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,8 @@ def reshape_array_for_tensor_product_space(
function along a given dimension will be represented by variation
of array entries along the corresponding array axis.
"""
if space.spatial_dim == 1:
return ary

ndim = len(ary.shape)
if axis < 0:
Expand Down Expand Up @@ -518,6 +520,8 @@ def unreshape_array_for_tensor_product_space(
"""Undoes the effect of :func:`reshape_array_for_tensor_product_space`,
given the same *space* and *axis*.
"""
if space.spatial_dim == 1:
return ary

n_tp_axes = len(space.bases)
naxes = len(ary.shape) - n_tp_axes + 1
Expand Down

0 comments on commit 84bb977

Please sign in to comment.