diff --git a/axlearn/common/causal_lm.py b/axlearn/common/causal_lm.py
index c1bb7dfb6..3dcdfdaa2 100644
--- a/axlearn/common/causal_lm.py
+++ b/axlearn/common/causal_lm.py
@@ -99,9 +99,11 @@ def forward(
         target_labels: Tensor = input_batch["target_labels"]
         target_num_bytes: Optional[Tensor] = input_batch.get("target_num_bytes")
+        live_targets: Optional[Tensor] = input_batch.get("live_targets")
 
         logits = predict_outputs["logits"]
-        live_targets = target_labels >= 0
+        if live_targets is None:
+            live_targets = target_labels >= 0
         num_targets = live_targets.sum()
         accuracy = (
             jnp.equal(jnp.argmax(logits, axis=-1), target_labels) * live_targets
@@ -226,9 +228,19 @@ class CompositeLossMetrics(BaseLossMetrics):
 
     @config_class
     class Config(BaseLossMetrics.Config):
-        """Configures CompositeLossMetrics."""
+        """Configures CompositeLossMetrics.
+
+        Attributes:
+            metrics: A mapping from child name to metrics config.
+            loss_weights: An optional mapping from child name to loss weight.
+                If None, all weights are considered 1.
+            flatten_metrics: Whether to flatten summaries and metrics from each child. If None,
+                defaults to True.
+        """
 
         metrics: Required[dict[str, BaseLossMetrics.Config]] = REQUIRED
+        loss_weights: Optional[dict[str, float]] = None
+        flatten_metrics: Optional[bool] = None
 
     def __init__(self, cfg, *, parent):
         super().__init__(cfg, parent=parent)
@@ -236,6 +248,8 @@ def __init__(self, cfg, *, parent):
         self._metrics: dict[str, BaseLossMetrics] = {}
         for name, child in cfg.metrics.items():
             self._metrics[name] = self._add_child(name, child)
+        if cfg.loss_weights is not None and cfg.metrics.keys() != cfg.loss_weights.keys():
+            raise ValueError(f"Expected {cfg.loss_weights.keys()=} to match {cfg.metrics.keys()}.")
 
     def forward(
         self,
@@ -249,6 +263,7 @@ def forward(
         By default, losses are summed and metrics/summaries are flattened, raising if any keys
         conflict.
         """
+        cfg: CompositeLossMetrics.Config = self.config
         loss = 0
         metrics = {}
@@ -258,12 +273,17 @@ def forward(
                 predict_outputs=predict_outputs,
                 module_outputs=module_outputs,
             )
+            if cfg.loss_weights is not None:
+                child_loss *= cfg.loss_weights[name]
             loss = loss + child_loss
             ctx = self.get_invocation_context()
-            # Flatten summaries for backwards compatibility.
-            _update(ctx.output_collection.summaries, ctx.output_collection.summaries.pop(name))
-            _update(metrics, child_metrics)
+
+            if cfg.flatten_metrics is False:
+                _update(metrics, {name: child_metrics})
+            else:
+                _update(ctx.output_collection.summaries, ctx.output_collection.summaries.pop(name))
+                _update(metrics, child_metrics)
 
         return loss, metrics
diff --git a/axlearn/common/causal_lm_test.py b/axlearn/common/causal_lm_test.py
index fa240b72d..aeb4bc229 100644
--- a/axlearn/common/causal_lm_test.py
+++ b/axlearn/common/causal_lm_test.py
@@ -345,7 +345,7 @@ def forward(self, *args, **kwargs):
                 self.add_summary(f"{self.name}_summary", 0)
                 self.add_module_output(f"{self.name}_output", 0)
                 self.add_state_update(f"{self.name}_state", 0)
-                return 0, {}
+                return 0, {f"{self.name}_output": 0}
 
         class DummyConflictModel(causal_lm.Model):
             def _metrics(self, *args, **kwargs):
@@ -383,11 +383,16 @@ def _metrics(self, *args, **kwargs):
                 self.add_state_update("parent_state", 1)
                 return super()._metrics(*args, **kwargs)
 
-        def test_no_conflict(metrics_cfg: BaseLossMetrics.Config, expected: OutputCollection):
-            _, output_collection = forward(DummyModel.default_config(), metrics_cfg)
-            self.assertNestedEqual(output_collection.summaries, expected.summaries)
-            self.assertNestedEqual(output_collection.module_outputs, expected.module_outputs)
-            self.assertNestedEqual(output_collection.state_updates, expected.state_updates)
+        def test_no_conflict(
+            metrics_cfg: BaseLossMetrics.Config,
+            expected_oc: OutputCollection,
+            expected_metrics: dict,
+        ):
+            (_, metrics), output_collection = forward(DummyModel.default_config(), metrics_cfg)
+            self.assertNestedEqual(output_collection.summaries, expected_oc.summaries)
+            self.assertNestedEqual(output_collection.module_outputs, expected_oc.module_outputs)
+            self.assertNestedEqual(output_collection.state_updates, expected_oc.state_updates)
+            self.assertNestedEqual(metrics, expected_metrics)
 
         test_no_conflict(
             DummyMetrics.default_config(),
@@ -396,16 +401,46 @@ def test_no_conflict(metrics_cfg: BaseLossMetrics.Config, expected: OutputCollec
                 module_outputs={"parent_output": 1, "metrics": {"metrics_output": 0}},
                 state_updates={"parent_state": 1, "metrics": {"metrics_state": 0}},
             ),
-        )
+            {"metrics_output": 0},
+        )
+        for flatten_metrics in (None, True):
+            test_no_conflict(
+                causal_lm.CompositeLossMetrics.default_config().set(
+                    metrics={
+                        "child1": DummyMetrics.default_config(),
+                        "child2": DummyMetrics.default_config(),
+                    },
+                    flatten_metrics=flatten_metrics,
+                ),
+                OutputCollection(
+                    summaries={"parent_summary": 1, "child1_summary": 0, "child2_summary": 0},
+                    module_outputs={
+                        "parent_output": 1,
+                        "metrics": {"child1": {"child1_output": 0}, "child2": {"child2_output": 0}},
+                    },
+                    state_updates={
+                        "parent_state": 1,
+                        "metrics": {"child1": {"child1_state": 0}, "child2": {"child2_state": 0}},
+                    },
+                ),
+                {"child1_output": 0, "child2_output": 0},
+            )
+
+        # Test without flattening.
         test_no_conflict(
             causal_lm.CompositeLossMetrics.default_config().set(
                 metrics={
                     "child1": DummyMetrics.default_config(),
                     "child2": DummyMetrics.default_config(),
-                }
+                },
+                flatten_metrics=False,
             ),
             OutputCollection(
-                summaries={"parent_summary": 1, "child1_summary": 0, "child2_summary": 0},
+                summaries={
+                    "parent_summary": 1,
+                    "child1": {"child1_summary": 0},
+                    "child2": {"child2_summary": 0},
+                },
                 module_outputs={
                     "parent_output": 1,
                     "metrics": {"child1": {"child1_output": 0}, "child2": {"child2_output": 0}},
@@ -415,6 +450,7 @@ def test_no_conflict(metrics_cfg: BaseLossMetrics.Config, expected: OutputCollec
                     "metrics": {"child1": {"child1_state": 0}, "child2": {"child2_state": 0}},
                 },
             ),
+            {"child1": {"child1_output": 0}, "child2": {"child2_output": 0}},
         )
 
     # TODO(markblee): Add a pytest marker for multi-device tests.
@@ -477,6 +513,90 @@ def fn(x):
         )
 
 
+class CrossEntropyLossMetricsTest(TestCase):
+    """Tests CrossEntropyLossMetrics."""
+
+    def test_live_targets(self):
+        batch_size, seq_len, vocab_size = 3, 10, 10
+        tgt_key, logit_key, live_tgt_key = jax.random.split(jax.random.PRNGKey(0), num=3)
+        target_labels = jax.random.randint(
+            tgt_key, shape=[batch_size, seq_len], minval=-1, maxval=vocab_size
+        )
+        logits = jax.random.uniform(logit_key, shape=[*target_labels.shape, vocab_size])
+        layer = (
+            causal_lm.CrossEntropyLossMetrics.default_config()
+            .set(name="test")
+            .instantiate(parent=None)
+        )
+
+        # Make sure at least one masked target.
+        assert jnp.any(target_labels == -1), target_labels
+
+        def forward(live_targets):
+            (loss, metrics), _ = functional(
+                layer,
+                prng_key=None,
+                state={},
+                inputs=dict(
+                    input_batch=dict(target_labels=target_labels, live_targets=live_targets),
+                    predict_outputs=dict(logits=logits),
+                    module_outputs={},
+                ),
+                is_training=True,
+            )
+            return loss, metrics
+
+        # Test without live_targets. Should be equivalent to target_labels >= 0.
+        test_loss, metrics = forward(live_targets=None)
+        ref_loss, _ = cross_entropy(logits, target_labels, live_targets=target_labels >= 0)
+        self.assertAlmostEqual(test_loss, ref_loss)
+        self.assertEqual(metrics["num_targets"], (target_labels >= 0).sum())
+
+        # Test with live_targets.
+        live_targets = jax.random.randint(
+            live_tgt_key, shape=target_labels.shape, minval=0, maxval=2
+        )
+        test_loss, metrics = forward(live_targets=live_targets)
+        ref_loss, _ = cross_entropy(logits, target_labels, live_targets=live_targets)
+        self.assertAlmostEqual(test_loss, ref_loss)
+        self.assertEqual(metrics["num_targets"], live_targets.sum())
+
+
+class CompositeLossMetricsTest(TestCase):
+    """Tests CompositeLossMetrics."""
+
+    def test_loss_weights(self):
+        class DummyMetrics(BaseLossMetrics):
+            def forward(self, input_batch, **kwargs):
+                del kwargs
+                return input_batch[self.name], {}
+
+        cfg = causal_lm.CompositeLossMetrics.default_config().set(
+            name="test",
+            metrics={
+                "test0": DummyMetrics.default_config(),
+                "test1": DummyMetrics.default_config(),
+            },
+        )
+
+        # Test mismatched keys.
+        with self.assertRaisesRegex(ValueError, "keys"):
+            cfg.set(loss_weights={"test0": 0.5}).instantiate(parent=None)
+
+        metrics = cfg.set(loss_weights={"test0": 0.5, "test1": 1.0}).instantiate(parent=None)
+
+        (loss, _), _ = functional(
+            metrics,
+            prng_key=jax.random.PRNGKey(123),
+            state={},
+            inputs=dict(
+                input_batch={"test0": 1.23, "test1": 3.45}, predict_outputs={}, module_outputs={}
+            ),
+            is_training=True,
+        )
+        self.assertAlmostEqual(loss, 1.23 * 0.5 + 3.45)
+
+
 class DummyFeedForwardWithAuxLoss(TransformerFeedForwardLayer):
     """A dummy FFN with aux loss."""