Supports loss_weights and live_targets in metrics. #960

Merged · 3 commits · Jan 31, 2025
axlearn/common/causal_lm.py — 30 changes: 25 additions & 5 deletions
@@ -99,9 +99,11 @@ def forward(

target_labels: Tensor = input_batch["target_labels"]
target_num_bytes: Optional[Tensor] = input_batch.get("target_num_bytes")
live_targets: Optional[Tensor] = input_batch.get("live_targets")
logits = predict_outputs["logits"]

live_targets = target_labels >= 0
if live_targets is None:
live_targets = target_labels >= 0
num_targets = live_targets.sum()
accuracy = (
jnp.equal(jnp.argmax(logits, axis=-1), target_labels) * live_targets
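The hunk above keeps the previous behavior as a fallback: when the input batch carries no `live_targets`, the mask is still `target_labels >= 0`. Below is a minimal sketch of how a caller could now supply its own mask, assuming a hypothetical prompt-masking setup; only the `target_labels` and `live_targets` key names come from this diff, the rest is illustrative.

```python
# Sketch only (not part of the PR): build an input_batch that overrides the default mask.
import jax.numpy as jnp

target_labels = jnp.array([[5, 7, 2, -1], [3, 9, -1, -1]])  # -1 marks padding
prompt_len = jnp.array([1, 2])  # hypothetical per-example prompt lengths
positions = jnp.arange(target_labels.shape[-1])[None, :]

# Without "live_targets", the metrics layer falls back to target_labels >= 0.
# Supplying it lets callers additionally drop positions, e.g. prompt tokens.
live_targets = (target_labels >= 0) & (positions >= prompt_len[:, None])

input_batch = dict(target_labels=target_labels, live_targets=live_targets)
```

With such a batch, padding and prompt positions are excluded from both the loss and the `num_targets` count.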
@@ -226,16 +228,28 @@ class CompositeLossMetrics(BaseLossMetrics):

@config_class
class Config(BaseLossMetrics.Config):
"""Configures CompositeLossMetrics."""
"""Configures CompositeLossMetrics.

Attributes:
metrics: A mapping from child name to metrics config.
loss_weights: An optional mapping from child name to loss weight.
If None, all weights are considered 1.
flatten_metrics: Whether to flatten summaries and metrics from each child. If None,
defaults to True.
"""

metrics: Required[dict[str, BaseLossMetrics.Config]] = REQUIRED
loss_weights: Optional[dict[str, float]] = None
flatten_metrics: Optional[bool] = None

def __init__(self, cfg, *, parent):
super().__init__(cfg, parent=parent)
cfg: CompositeLossMetrics.Config = self.config
self._metrics: dict[str, BaseLossMetrics] = {}
for name, child in cfg.metrics.items():
self._metrics[name] = self._add_child(name, child)
if cfg.loss_weights is not None and cfg.metrics.keys() != cfg.loss_weights.keys():
raise ValueError(f"Expected {cfg.loss_weights.keys()=} to match {cfg.metrics.keys()}.")

def forward(
self,
@@ -249,6 +263,7 @@ def forward(
By default, losses are summed and metrics/summaries are flattened, raising if any keys
conflict.
"""
cfg: CompositeLossMetrics.Config = self.config
loss = 0
metrics = {}

@@ -258,12 +273,17 @@
predict_outputs=predict_outputs,
module_outputs=module_outputs,
)
if cfg.loss_weights is not None:
child_loss *= cfg.loss_weights[name]
loss = loss + child_loss

ctx = self.get_invocation_context()
# Flatten summaries for backwards compatibility.
_update(ctx.output_collection.summaries, ctx.output_collection.summaries.pop(name))
_update(metrics, child_metrics)

if cfg.flatten_metrics is False:
_update(metrics, {name: child_metrics})
else:
_update(ctx.output_collection.summaries, ctx.output_collection.summaries.pop(name))
_update(metrics, child_metrics)

return loss, metrics
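
For reference, here is a minimal configuration sketch that combines the new `loss_weights` and `flatten_metrics` options. `CompositeLossMetrics`, `CrossEntropyLossMetrics`, and `BaseLossMetrics` appear in this file; `AuxLossMetrics` and the `aux_loss` batch key are invented here for illustration.

```python
# Sketch only (not part of the PR): weight two loss terms and keep their metrics nested.
from axlearn.common import causal_lm


class AuxLossMetrics(causal_lm.BaseLossMetrics):
    """Hypothetical extra loss term, mirroring the DummyMetrics pattern in the tests."""

    def forward(self, input_batch, **kwargs):
        del kwargs
        # Illustrative only: read a precomputed auxiliary loss from the batch.
        return input_batch["aux_loss"], {}


metrics_cfg = causal_lm.CompositeLossMetrics.default_config().set(
    metrics={
        "lm": causal_lm.CrossEntropyLossMetrics.default_config(),
        "aux": AuxLossMetrics.default_config(),
    },
    # Keys must match `metrics` exactly; a mismatch raises ValueError at init.
    loss_weights={"lm": 1.0, "aux": 0.1},
    # None/True flattens child metrics and summaries (the previous behavior);
    # False nests them under the child name instead.
    flatten_metrics=False,
)
```

The total loss is then `1.0 * lm_loss + 0.1 * aux_loss`, and with `flatten_metrics=False` the returned metrics are nested as `{"lm": {...}, "aux": {...}}`, which avoids key conflicts between children.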

axlearn/common/causal_lm_test.py — 138 changes: 129 additions & 9 deletions
@@ -345,7 +345,7 @@ def forward(self, *args, **kwargs):
self.add_summary(f"{self.name}_summary", 0)
self.add_module_output(f"{self.name}_output", 0)
self.add_state_update(f"{self.name}_state", 0)
return 0, {}
return 0, {f"{self.name}_output": 0}

class DummyConflictModel(causal_lm.Model):
def _metrics(self, *args, **kwargs):
@@ -383,11 +383,16 @@ def _metrics(self, *args, **kwargs):
self.add_state_update("parent_state", 1)
return super()._metrics(*args, **kwargs)

def test_no_conflict(metrics_cfg: BaseLossMetrics.Config, expected: OutputCollection):
_, output_collection = forward(DummyModel.default_config(), metrics_cfg)
self.assertNestedEqual(output_collection.summaries, expected.summaries)
self.assertNestedEqual(output_collection.module_outputs, expected.module_outputs)
self.assertNestedEqual(output_collection.state_updates, expected.state_updates)
def test_no_conflict(
metrics_cfg: BaseLossMetrics.Config,
expected_oc: OutputCollection,
expected_metrics: dict,
):
(_, metrics), output_collection = forward(DummyModel.default_config(), metrics_cfg)
self.assertNestedEqual(output_collection.summaries, expected_oc.summaries)
self.assertNestedEqual(output_collection.module_outputs, expected_oc.module_outputs)
self.assertNestedEqual(output_collection.state_updates, expected_oc.state_updates)
self.assertNestedEqual(metrics, expected_metrics)

test_no_conflict(
DummyMetrics.default_config(),
@@ -396,16 +401,46 @@ def test_no_conflict(metrics_cfg: BaseLossMetrics.Config, expected: OutputCollec
module_outputs={"parent_output": 1, "metrics": {"metrics_output": 0}},
state_updates={"parent_state": 1, "metrics": {"metrics_state": 0}},
),
)
{"metrics_output": 0},
)
for flatten_metrics in (None, True):
test_no_conflict(
causal_lm.CompositeLossMetrics.default_config().set(
metrics={
"child1": DummyMetrics.default_config(),
"child2": DummyMetrics.default_config(),
},
flatten_metrics=flatten_metrics,
),
OutputCollection(
summaries={"parent_summary": 1, "child1_summary": 0, "child2_summary": 0},
module_outputs={
"parent_output": 1,
"metrics": {"child1": {"child1_output": 0}, "child2": {"child2_output": 0}},
},
state_updates={
"parent_state": 1,
"metrics": {"child1": {"child1_state": 0}, "child2": {"child2_state": 0}},
},
),
{"child1_output": 0, "child2_output": 0},
)

# Test without flattening.
test_no_conflict(
causal_lm.CompositeLossMetrics.default_config().set(
metrics={
"child1": DummyMetrics.default_config(),
"child2": DummyMetrics.default_config(),
}
},
flatten_metrics=False,
),
OutputCollection(
summaries={"parent_summary": 1, "child1_summary": 0, "child2_summary": 0},
summaries={
"parent_summary": 1,
"child1": {"child1_summary": 0},
"child2": {"child2_summary": 0},
},
module_outputs={
"parent_output": 1,
"metrics": {"child1": {"child1_output": 0}, "child2": {"child2_output": 0}},
Expand All @@ -415,6 +450,7 @@ def test_no_conflict(metrics_cfg: BaseLossMetrics.Config, expected: OutputCollec
"metrics": {"child1": {"child1_state": 0}, "child2": {"child2_state": 0}},
},
),
{"child1": {"child1_output": 0}, "child2": {"child2_output": 0}},
)

# TODO(markblee): Add a pytest marker for multi-device tests.
@@ -477,6 +513,90 @@ def fn(x):
)


class CrossEntropyLossMetricsTest(TestCase):
"""Tests CrossEntropyLossMetrics."""

def test_live_targets(self):
batch_size, seq_len, vocab_size = 3, 10, 10
tgt_key, logit_key, live_tgt_key = jax.random.split(jax.random.PRNGKey(0), num=3)
target_labels = jax.random.randint(
tgt_key, shape=[batch_size, seq_len], minval=-1, maxval=vocab_size
)
logits = jax.random.uniform(logit_key, shape=[*target_labels.shape, vocab_size])
layer = (
causal_lm.CrossEntropyLossMetrics.default_config()
.set(name="test")
.instantiate(parent=None)
)

# Make sure at least one masked target.
assert jnp.any(target_labels == -1), target_labels

def forward(live_targets):
(loss, metrics), _ = functional(
layer,
prng_key=None,
state={},
inputs=dict(
input_batch=dict(target_labels=target_labels, live_targets=live_targets),
predict_outputs=dict(logits=logits),
module_outputs={},
),
is_training=True,
)
return loss, metrics

# Test without live_targets. Should be equivalent to target_labels >= 0.
test_loss, metrics = forward(live_targets=None)
ref_loss, _ = cross_entropy(logits, target_labels, live_targets=target_labels >= 0)
self.assertAlmostEqual(test_loss, ref_loss)
self.assertEqual(metrics["num_targets"], (target_labels >= 0).sum())

# Test with live_targets.
live_targets = jax.random.randint(
live_tgt_key, shape=target_labels.shape, minval=0, maxval=2
)
test_loss, metrics = forward(live_targets=live_targets)
ref_loss, _ = cross_entropy(logits, target_labels, live_targets=live_targets)
self.assertAlmostEqual(test_loss, ref_loss)
self.assertEqual(metrics["num_targets"], live_targets.sum())


class CompositeLossMetricsTest(TestCase):
"""Tests CompositeLossMetrics."""

def test_loss_weights(self):
class DummyMetrics(BaseLossMetrics):
def forward(self, input_batch, **kwargs):
del kwargs
return input_batch[self.name], {}

cfg = causal_lm.CompositeLossMetrics.default_config().set(
name="test",
metrics={
"test0": DummyMetrics.default_config(),
"test1": DummyMetrics.default_config(),
},
)

# Test mismatched keys.
with self.assertRaisesRegex(ValueError, "keys"):
cfg.set(loss_weights={"test0": 0.5}).instantiate(parent=None)

metrics = cfg.set(loss_weights={"test0": 0.5, "test1": 1.0}).instantiate(parent=None)

(loss, _), _ = functional(
metrics,
prng_key=jax.random.PRNGKey(123),
state={},
inputs=dict(
input_batch={"test0": 1.23, "test1": 3.45}, predict_outputs={}, module_outputs={}
),
is_training=True,
)
self.assertAlmostEqual(loss, 1.23 * 0.5 + 3.45)


class DummyFeedForwardWithAuxLoss(TransformerFeedForwardLayer):
"""A dummy FFN with aux loss."""
