From 55226bbee8151cf38ff380f0aa79ef575776f079 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Wed, 10 Jan 2024 11:23:23 -0800 Subject: [PATCH 1/2] Avoid computing loss when weight is zero --- src/fmmax/vector.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/fmmax/vector.py b/src/fmmax/vector.py index 86fb3a1..473c27a 100644 --- a/src/fmmax/vector.py +++ b/src/fmmax/vector.py @@ -393,17 +393,28 @@ def _field_loss( shape=shape, axis=-2, ) - alignment_loss = _alignment_loss( - field, target_field, elementwise_alignment_loss_weight - ) + loss = _alignment_loss(field, target_field, elementwise_alignment_loss_weight) + + # Avoid calculating the fourier loss and smoothness loss if their weights are zero. + # On some platforms, including a smoothness loss can signifcantly slow the compile + # times, and so this optimization increases performance. + with jax.ensure_compile_time_eval(): + assert jnp.size(fourier_loss_weight) == 1 + assert jnp.size(smoothness_loss_weight) == 1 + use_fourier_loss = fourier_loss_weight > 0 + use_smoothness_loss = smoothness_loss_weight > 0 + + if use_fourier_loss: + loss += fourier_loss_weight * _fourier_loss( + fourier_field, expansion, primitive_lattice_vectors + ) - fourier_loss = _fourier_loss(fourier_field, expansion, primitive_lattice_vectors) - smoothness_loss = _smoothness_loss(field, primitive_lattice_vectors) - return ( - alignment_loss - + fourier_loss_weight * fourier_loss - + smoothness_loss_weight * smoothness_loss - ) + if use_smoothness_loss: + loss += smoothness_loss_weight * _smoothness_loss( + field, primitive_lattice_vectors + ) + + return loss def _alignment_loss( From 5fe7042613e5272fec1f1f40ad420ecf74e7e371 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Wed, 10 Jan 2024 11:38:52 -0800 Subject: [PATCH 2/2] Version updated from v0.5.1 to v0.5.2 --- .bumpversion.toml | 2 +- pyproject.toml | 2 +- src/fmmax/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.bumpversion.toml b/.bumpversion.toml index 67e6da7..5865569 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "v0.5.1" +current_version = "v0.5.2" commit = true commit_args = "--no-verify" tag = true diff --git a/pyproject.toml b/pyproject.toml index 99251b2..7ef1bcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "fmmax" -version = "v0.5.1" +version = "v0.5.2" description = "Fourier modal method with Jax" readme = "README.md" requires-python = ">=3.7" diff --git a/src/fmmax/__init__.py b/src/fmmax/__init__.py index 1a30bda..68225e5 100644 --- a/src/fmmax/__init__.py +++ b/src/fmmax/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -__version__ = "v0.5.1" +__version__ = "v0.5.2" from . import ( basis,