Supports loss_weights and live_targets in metrics. (#960)
* Supports loss_weights, live_targets, and module sharing in metrics.

* Addresses comments.

* Explicitly test flatten_metrics=True.
markblee authored Jan 31, 2025
1 parent 7a40f91 · commit 0936a17
Showing 2 changed files with 154 additions and 14 deletions.
30 changes: 25 additions & 5 deletions axlearn/common/causal_lm.py
@@ -99,9 +99,11 @@ def forward(

target_labels: Tensor = input_batch["target_labels"]
target_num_bytes: Optional[Tensor] = input_batch.get("target_num_bytes")
live_targets: Optional[Tensor] = input_batch.get("live_targets")
logits = predict_outputs["logits"]

live_targets = target_labels >= 0
if live_targets is None:
live_targets = target_labels >= 0
num_targets = live_targets.sum()
accuracy = (
jnp.equal(jnp.argmax(logits, axis=-1), target_labels) * live_targets
@@ -226,16 +228,28 @@ class CompositeLossMetrics(BaseLossMetrics):

@config_class
class Config(BaseLossMetrics.Config):
"""Configures CompositeLossMetrics."""
"""Configures CompositeLossMetrics.
Attributes:
metrics: A mapping from child name to metrics config.
loss_weights: An optional mapping from child name to loss weight.
If None, all weights are considered 1.
flatten_metrics: Whether to flatten summaries and metrics from each child. If None,
defaults to True.
"""

metrics: Required[dict[str, BaseLossMetrics.Config]] = REQUIRED
loss_weights: Optional[dict[str, float]] = None
flatten_metrics: Optional[bool] = None

def __init__(self, cfg, *, parent):
super().__init__(cfg, parent=parent)
cfg: CompositeLossMetrics.Config = self.config
self._metrics: dict[str, BaseLossMetrics] = {}
for name, child in cfg.metrics.items():
self._metrics[name] = self._add_child(name, child)
if cfg.loss_weights is not None and cfg.metrics.keys() != cfg.loss_weights.keys():
raise ValueError(f"Expected {cfg.loss_weights.keys()=} to match {cfg.metrics.keys()}.")

def forward(
self,
@@ -249,6 +263,7 @@ def forward(
By default, losses are summed and metrics/summaries are flattened, raising if any keys
conflict.
"""
cfg: CompositeLossMetrics.Config = self.config
loss = 0
metrics = {}

@@ -258,12 +273,17 @@ def forward(
predict_outputs=predict_outputs,
module_outputs=module_outputs,
)
if cfg.loss_weights is not None:
child_loss *= cfg.loss_weights[name]
loss = loss + child_loss

ctx = self.get_invocation_context()
# Flatten summaries for backwards compatibility.
_update(ctx.output_collection.summaries, ctx.output_collection.summaries.pop(name))
_update(metrics, child_metrics)

if cfg.flatten_metrics is False:
_update(metrics, {name: child_metrics})
else:
_update(ctx.output_collection.summaries, ctx.output_collection.summaries.pop(name))
_update(metrics, child_metrics)

return loss, metrics

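For context, a minimal usage sketch of the new CompositeLossMetrics options, assuming AXLearn's standard config API; the child names "lm" and "aux" and the weight values are illustrative, not part of this commit:

# Hypothetical sketch: per-child loss weights and nested (unflattened) metrics.
metrics_cfg = causal_lm.CompositeLossMetrics.default_config().set(
    metrics={
        "lm": causal_lm.CrossEntropyLossMetrics.default_config(),
        "aux": causal_lm.CrossEntropyLossMetrics.default_config(),
    },
    # Keys must match `metrics` exactly; otherwise __init__ raises ValueError.
    loss_weights={"lm": 1.0, "aux": 0.1},
    # None or True flattens child summaries/metrics; False nests them under each child name.
    flatten_metrics=False,
)
# The combined loss is the sum of loss_weights[name] * child_loss over the children.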
138 changes: 129 additions & 9 deletions axlearn/common/causal_lm_test.py
@@ -345,7 +345,7 @@ def forward(self, *args, **kwargs):
self.add_summary(f"{self.name}_summary", 0)
self.add_module_output(f"{self.name}_output", 0)
self.add_state_update(f"{self.name}_state", 0)
return 0, {}
return 0, {f"{self.name}_output": 0}

class DummyConflictModel(causal_lm.Model):
def _metrics(self, *args, **kwargs):
@@ -383,11 +383,16 @@ def _metrics(self, *args, **kwargs):
self.add_state_update("parent_state", 1)
return super()._metrics(*args, **kwargs)

def test_no_conflict(metrics_cfg: BaseLossMetrics.Config, expected: OutputCollection):
_, output_collection = forward(DummyModel.default_config(), metrics_cfg)
self.assertNestedEqual(output_collection.summaries, expected.summaries)
self.assertNestedEqual(output_collection.module_outputs, expected.module_outputs)
self.assertNestedEqual(output_collection.state_updates, expected.state_updates)
def test_no_conflict(
metrics_cfg: BaseLossMetrics.Config,
expected_oc: OutputCollection,
expected_metrics: dict,
):
(_, metrics), output_collection = forward(DummyModel.default_config(), metrics_cfg)
self.assertNestedEqual(output_collection.summaries, expected_oc.summaries)
self.assertNestedEqual(output_collection.module_outputs, expected_oc.module_outputs)
self.assertNestedEqual(output_collection.state_updates, expected_oc.state_updates)
self.assertNestedEqual(metrics, expected_metrics)

test_no_conflict(
DummyMetrics.default_config(),
@@ -396,16 +401,46 @@ def test_no_conflict(metrics_cfg: BaseLossMetrics.Config, expected: OutputCollec
module_outputs={"parent_output": 1, "metrics": {"metrics_output": 0}},
state_updates={"parent_state": 1, "metrics": {"metrics_state": 0}},
),
)
{"metrics_output": 0},
)
for flatten_metrics in (None, True):
test_no_conflict(
causal_lm.CompositeLossMetrics.default_config().set(
metrics={
"child1": DummyMetrics.default_config(),
"child2": DummyMetrics.default_config(),
},
flatten_metrics=flatten_metrics,
),
OutputCollection(
summaries={"parent_summary": 1, "child1_summary": 0, "child2_summary": 0},
module_outputs={
"parent_output": 1,
"metrics": {"child1": {"child1_output": 0}, "child2": {"child2_output": 0}},
},
state_updates={
"parent_state": 1,
"metrics": {"child1": {"child1_state": 0}, "child2": {"child2_state": 0}},
},
),
{"child1_output": 0, "child2_output": 0},
)

# Test without flattening.
test_no_conflict(
causal_lm.CompositeLossMetrics.default_config().set(
metrics={
"child1": DummyMetrics.default_config(),
"child2": DummyMetrics.default_config(),
}
},
flatten_metrics=False,
),
OutputCollection(
summaries={"parent_summary": 1, "child1_summary": 0, "child2_summary": 0},
summaries={
"parent_summary": 1,
"child1": {"child1_summary": 0},
"child2": {"child2_summary": 0},
},
module_outputs={
"parent_output": 1,
"metrics": {"child1": {"child1_output": 0}, "child2": {"child2_output": 0}},
@@ -415,6 +450,7 @@ def test_no_conflict(metrics_cfg: BaseLossMetrics.Config, expected: OutputCollec
"metrics": {"child1": {"child1_state": 0}, "child2": {"child2_state": 0}},
},
),
{"child1": {"child1_output": 0}, "child2": {"child2_output": 0}},
)

# TODO(markblee): Add a pytest marker for multi-device tests.
@@ -477,6 +513,90 @@ def fn(x):
)


class CrossEntropyLossMetricsTest(TestCase):
"""Tests CrossEntropyLossMetrics."""

def test_live_targets(self):
batch_size, seq_len, vocab_size = 3, 10, 10
tgt_key, logit_key, live_tgt_key = jax.random.split(jax.random.PRNGKey(0), num=3)
target_labels = jax.random.randint(
tgt_key, shape=[batch_size, seq_len], minval=-1, maxval=vocab_size
)
logits = jax.random.uniform(logit_key, shape=[*target_labels.shape, vocab_size])
layer = (
causal_lm.CrossEntropyLossMetrics.default_config()
.set(name="test")
.instantiate(parent=None)
)

# Make sure there is at least one masked target.
assert jnp.any(target_labels == -1), target_labels

def forward(live_targets):
(loss, metrics), _ = functional(
layer,
prng_key=None,
state={},
inputs=dict(
input_batch=dict(target_labels=target_labels, live_targets=live_targets),
predict_outputs=dict(logits=logits),
module_outputs={},
),
is_training=True,
)
return loss, metrics

# Test without live_targets. Should be equivalent to target_labels >= 0.
test_loss, metrics = forward(live_targets=None)
ref_loss, _ = cross_entropy(logits, target_labels, live_targets=target_labels >= 0)
self.assertAlmostEqual(test_loss, ref_loss)
self.assertEqual(metrics["num_targets"], (target_labels >= 0).sum())

# Test with live_targets.
live_targets = jax.random.randint(
live_tgt_key, shape=target_labels.shape, minval=0, maxval=2
)
test_loss, metrics = forward(live_targets=live_targets)
ref_loss, _ = cross_entropy(logits, target_labels, live_targets=live_targets)
self.assertAlmostEqual(test_loss, ref_loss)
self.assertEqual(metrics["num_targets"], live_targets.sum())


class CompositeLossMetricsTest(TestCase):
"""Tests CompositeLossMetrics."""

def test_loss_weights(self):
class DummyMetrics(BaseLossMetrics):
def forward(self, input_batch, **kwargs):
del kwargs
return input_batch[self.name], {}

cfg = causal_lm.CompositeLossMetrics.default_config().set(
name="test",
metrics={
"test0": DummyMetrics.default_config(),
"test1": DummyMetrics.default_config(),
},
)

# Test mismatched keys.
with self.assertRaisesRegex(ValueError, "keys"):
cfg.set(loss_weights={"test0": 0.5}).instantiate(parent=None)

metrics = cfg.set(loss_weights={"test0": 0.5, "test1": 1.0}).instantiate(parent=None)

(loss, _), _ = functional(
metrics,
prng_key=jax.random.PRNGKey(123),
state={},
inputs=dict(
input_batch={"test0": 1.23, "test1": 3.45}, predict_outputs={}, module_outputs={}
),
is_training=True,
)
self.assertAlmostEqual(loss, 1.23 * 0.5 + 3.45)


class DummyFeedForwardWithAuxLoss(TransformerFeedForwardLayer):
"""A dummy FFN with aux loss."""

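Similarly, a minimal sketch of supplying the new live_targets mask in the input batch; the shapes and values are illustrative, and when the key is omitted the loss falls back to masking with target_labels >= 0:

import jax.numpy as jnp

# Hypothetical input batch for CrossEntropyLossMetrics.forward:
# 1 marks positions that contribute to the loss/accuracy, 0 marks ignored positions.
input_batch = dict(
    target_labels=jnp.array([[5, 2, -1], [3, -1, -1]]),
    live_targets=jnp.array([[1, 0, 0], [1, 1, 0]]),
)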
