Skip to content

Commit

Permalink
Adding Sine and Cosine transform
Browse files Browse the repository at this point in the history
  • Loading branch information
tsunhopang committed Oct 18, 2024
1 parent 3fa2ccb commit 85702ae
Showing 1 changed file with 60 additions and 2 deletions.
62 changes: 60 additions & 2 deletions src/jimgw/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def transform(self, x: dict[str, Float]) -> tuple[dict[str, Float], Float]:
output_params = self.transform_func(transform_params)
jacobian = jax.jacfwd(self.transform_func)(transform_params)
jacobian = jnp.array(jax.tree.leaves(jacobian))
jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))))
jacobian = jnp.log(
jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))
)
jax.tree.map(
lambda key: x_copy.pop(key),
self.name_mapping[0],
Expand Down Expand Up @@ -124,7 +126,9 @@ def inverse(self, y: dict[str, Float]) -> tuple[dict[str, Float], Float]:
output_params = self.inverse_transform_func(transform_params)
jacobian = jax.jacfwd(self.inverse_transform_func)(transform_params)
jacobian = jnp.array(jax.tree.leaves(jacobian))
jacobian = jnp.log(jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim))))
jacobian = jnp.log(
jnp.absolute(jnp.linalg.det(jacobian.reshape(self.n_dim, self.n_dim)))
)
jax.tree.map(
lambda key: y_copy.pop(key),
self.name_mapping[1],
Expand Down Expand Up @@ -298,6 +302,60 @@ def __init__(
}


@jaxtyped(typechecker=typechecker)
class SineTransform(BijectiveTransform):
"""
Sine transformation
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)
self.transform_func = lambda x: {
name_mapping[1][i]: jnp.sin(x[name_mapping[0][i]])
for i in range(len(name_mapping[0]))
}
self.inverse_transform_func = lambda x: {
name_mapping[0][i]: jnp.arcsin(x[name_mapping[1][i]])
for i in range(len(name_mapping[1]))
}


@jaxtyped(typechecker=typechecker)
class CosineTransform(BijectiveTransform):
"""
Cosine transformation
Parameters
----------
name_mapping : tuple[list[str], list[str]]
The name mapping between the input and output dictionary.
"""

def __init__(
self,
name_mapping: tuple[list[str], list[str]],
):
super().__init__(name_mapping)
self.transform_func = lambda x: {
name_mapping[1][i]: jnp.cos(x[name_mapping[0][i]])
for i in range(len(name_mapping[0]))
}
self.inverse_transform_func = lambda x: {
name_mapping[0][i]: jnp.arccos(x[name_mapping[1][i]])
for i in range(len(name_mapping[1]))
}


@jaxtyped(typechecker=typechecker)
class ArcSineTransform(BijectiveTransform):
"""
Expand Down

0 comments on commit 85702ae

Please sign in to comment.