Skip to content

Commit

Permalink
Merge pull request #75 from facebookresearch/opt
Browse files Browse the repository at this point in the history
Avoid computing smoothness/fourier loss when weight is zero
  • Loading branch information
mfschubert authored Jan 10, 2024
2 parents 36a71f0 + 5fe7042 commit d461c4b
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "v0.5.1"
current_version = "v0.5.2"
commit = true
commit_args = "--no-verify"
tag = true
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]

name = "fmmax"
version = "v0.5.1"
version = "v0.5.2"
description = "Fourier modal method with Jax"
readme = "README.md"
requires-python = ">=3.7"
Expand Down
2 changes: 1 addition & 1 deletion src/fmmax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

__version__ = "v0.5.1"
__version__ = "v0.5.2"

from . import (
basis,
Expand Down
31 changes: 21 additions & 10 deletions src/fmmax/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,17 +393,28 @@ def _field_loss(
shape=shape,
axis=-2,
)
alignment_loss = _alignment_loss(
field, target_field, elementwise_alignment_loss_weight
)
loss = _alignment_loss(field, target_field, elementwise_alignment_loss_weight)

# Avoid calculating the fourier loss and smoothness loss if their weights are zero.
# On some platforms, including a smoothness loss can signifcantly slow the compile
# times, and so this optimization increases performance.
with jax.ensure_compile_time_eval():
assert jnp.size(fourier_loss_weight) == 1
assert jnp.size(smoothness_loss_weight) == 1
use_fourier_loss = fourier_loss_weight > 0
use_smoothness_loss = smoothness_loss_weight > 0

if use_fourier_loss:
loss += fourier_loss_weight * _fourier_loss(
fourier_field, expansion, primitive_lattice_vectors
)

fourier_loss = _fourier_loss(fourier_field, expansion, primitive_lattice_vectors)
smoothness_loss = _smoothness_loss(field, primitive_lattice_vectors)
return (
alignment_loss
+ fourier_loss_weight * fourier_loss
+ smoothness_loss_weight * smoothness_loss
)
if use_smoothness_loss:
loss += smoothness_loss_weight * _smoothness_loss(
field, primitive_lattice_vectors
)

return loss


def _alignment_loss(
Expand Down

0 comments on commit d461c4b

Please sign in to comment.