From 526902ff1987103f54d2dd175e9c53eede25ad76 Mon Sep 17 00:00:00 2001 From: "Zheng-Yong (Arsa) Ang" Date: Fri, 24 Jan 2025 21:22:51 -0800 Subject: [PATCH] explicitly disable pt2 compile on gauc (#2703) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2703 Context: the recent changes to GAUC have made it incompatible with PT2 compile, thus causing several issues with launching jobs on the latest trunk. After some digging, it turns out that the incompatibility lies with `torch.Tensor.item()`: https://github.com/pytorch/pytorch/issues/130917. This diff: explicitly disables PT2 compilation on the problematic function, thus preventing training jobs from crashing due to incompatibility with PT2 compile. Reviewed By: shz117 Differential Revision: D68629159 fbshipit-source-id: 7cac0c6ebe720f02033ac807a47315fb20996595 --- torchrec/metrics/gauc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrec/metrics/gauc.py b/torchrec/metrics/gauc.py index d6980a574..0edf9c529 100644 --- a/torchrec/metrics/gauc.py +++ b/torchrec/metrics/gauc.py @@ -100,6 +100,7 @@ def to_3d( return torch.ops.fbgemm.jagged_2d_to_dense(tensor_2d, offsets, max_length) +@torch.compiler.disable def get_auc_states( labels: torch.Tensor, predictions: torch.Tensor,