From 795f28ac552eb61d02ea02fd64637ba814133bd8 Mon Sep 17 00:00:00 2001
From: chilli
Date: Sun, 1 Dec 2024 13:01:41 -0800
Subject: [PATCH] Ensure that BlockMask length must always exactly match the
 sequence length in flex_attention (#141625)

Fixes https://github.com/pytorch/pytorch/issues/141435

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141625
Approved by: https://github.com/drisspg
ghstack dependencies: #138788
---
 test/inductor/test_flex_attention.py      | 149 ++++++++++--------
 test/inductor/test_flex_decoding.py       |  14 +-
 torch/_higher_order_ops/flex_attention.py |  13 +-
 torch/_inductor/kernel/flex_attention.py  |  21 +--
 torch/_inductor/kernel/flex_decoding.py   |   2 +
 .../experimental/_paged_attention.py      |   2 +
 torch/nn/attention/flex_attention.py      |  88 +++++++----
 7 files changed, 171 insertions(+), 118 deletions(-)

diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index b433621f8b308..288933fae8ece 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -668,62 +668,64 @@ def run_dynamic_test(
         D: int = D,
     ):
         score_mod, mask_mod = score_mask_mod
-        # If the seqlen becomes smaller than the seqlen of the previous batch,
-        # we can still reuse the block_mask created from a larger seqlen.
-        MAX_S = S
-        block_mask = create_block_mask(mask_mod, 1, 1, MAX_S, MAX_S)
-        sdpa_partial = create_attention(score_mod, block_mask=block_mask)
-        # The first eager batch, shape (B, H, S, D)
+
+        # First batch with original dimensions (B, H, S, D)
+        block_mask1 = create_block_mask(mask_mod, 1, 1, S, S)
+        sdpa_partial1 = create_attention(score_mod, block_mask=block_mask1)
+
         q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
         k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
         v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
         q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1)
         q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64)
-        ref_out1 = sdpa_partial(q1_ref, k1_ref, v1_ref)
-        golden_out1 = sdpa_partial(q1_gold, k1_gold, v1_gold)
+        ref_out1 = sdpa_partial1(q1_ref, k1_ref, v1_ref)
+        golden_out1 = sdpa_partial1(q1_gold, k1_gold, v1_gold)

         backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
-
         golden_out1.backward(backward_grad1.to(torch.float64))
         ref_out1.backward(backward_grad1)

-        # The second eager batch, shape (B * 2, H, S / 2, D)
+        # Second batch with modified dimensions (B * 2, H, S / 2, D)
         B = int(B * 2)
         S = int(S / 2)
+        block_mask2 = create_block_mask(mask_mod, 1, 1, S, S)
+        sdpa_partial2 = create_attention(score_mod, block_mask=block_mask2)
+
         q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
         k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
         v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
         q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2)
         q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64)
-        ref_out2 = sdpa_partial(q2_ref, k2_ref, v2_ref)
-        golden_out2 = sdpa_partial(q2_gold, k2_gold, v2_gold)
+        ref_out2 = sdpa_partial2(q2_ref, k2_ref, v2_ref)
+        golden_out2 = sdpa_partial2(q2_gold, k2_gold, v2_gold)

         backward_grad2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
-
         golden_out2.backward(backward_grad2.to(torch.float64))
         ref_out2.backward(backward_grad2)

-        # The third eager batch, shape (B * 2, H, S / 4, D)
+        # Third batch with modified dimensions (B * 2, H, S / 4, D)
         S = int(S / 2)
+        block_mask3 = create_block_mask(mask_mod, 1, 1, S, S)
+        sdpa_partial3 = create_attention(score_mod, block_mask=block_mask3)
+
         q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
         k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
         v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
         q3_ref, k3_ref, v3_ref = query_key_value_clones(q3, k3, v3)
         q3_gold, k3_gold, v3_gold = query_key_value_clones(q3, k3, v3, torch.float64)
-        ref_out3 = sdpa_partial(q3_ref, k3_ref, v3_ref)
-        golden_out3 = sdpa_partial(q3_gold, k3_gold, v3_gold)
+        ref_out3 = sdpa_partial3(q3_ref, k3_ref, v3_ref)
+        golden_out3 = sdpa_partial3(q3_gold, k3_gold, v3_gold)

         backward_grad3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
-
         golden_out3.backward(backward_grad3.to(torch.float64))
         ref_out3.backward(backward_grad3)

-        # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
-        # We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
+        # Clear dynamo counters
         torch._dynamo.reset()
-        # Compiling with dynamic shape in the first batch.
-        compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
-        compiled_out1 = compiled_sdpa(q1, k1, v1)
+
+        # First compilation with original dimensions
+        compiled_sdpa1 = torch.compile(sdpa_partial1, dynamic=True)
+        compiled_out1 = compiled_sdpa1(q1, k1, v1)
         compiled_out1.backward(backward_grad1)

         self._check_out_and_grad(
@@ -742,10 +744,11 @@ def run_dynamic_test(
         )
         self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

-        # Since current q_seqlen (MAX_S/2) is smaller than the seqlen from block_mask (MAX_S),
-        # recompile to include the BlockMask._adjust part.
-        compiled_out2 = compiled_sdpa(q2, k2, v2)
+        # Second compilation with new dimensions
+        compiled_sdpa2 = torch.compile(sdpa_partial2, dynamic=True)
+        compiled_out2 = compiled_sdpa2(q2, k2, v2)
         compiled_out2.backward(backward_grad2)
+
         self._check_out_and_grad(
             golden_out2,
             ref_out2,
@@ -760,13 +763,13 @@ def run_dynamic_test(
             v2_ref,
             v2,
         )
-        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

-        # No re-compilation, use the compiled dynamic shape version.
-        # The current q_seqlen (MAX_S/4) is still smaller than the seqlen from block_mask (MAX_S),
-        # we don't recompile since we can reuse the compiled graph, which already includes the BlockMask._adjust part.
-        compiled_out3 = compiled_sdpa(q3, k3, v3)
+        # Third compilation with new dimensions
+        compiled_sdpa3 = torch.compile(sdpa_partial3, dynamic=True)
+        compiled_out3 = compiled_sdpa3(q3, k3, v3)
         compiled_out3.backward(backward_grad3)
+
         self._check_out_and_grad(
             golden_out3,
             ref_out3,
@@ -781,18 +784,7 @@ def run_dynamic_test(
             v3_ref,
             v3,
         )
-        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
-
-        # The forth iteration, shape (B * 2, H, S * 2, D)
-        # Since seqlen is larger than the seqlen in block_mask, throw errors.
-        S = int(S * 8)
-        q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
-        k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
-        v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
-        with self.assertRaisesRegex(
-            torch._dynamo.exc.BackendCompilerFailed, "Q seqlen must be smaller than"
-        ):
-            compiled_sdpa(q3, k3, v3)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

     def run_automatic_dynamic_test(
         self,
@@ -804,38 +796,42 @@ def run_automatic_dynamic_test(
         D: int = D,
     ):
         MAX_S = S
-        block_mask = create_block_mask(noop_mask, 1, 1, MAX_S, MAX_S)
-        sdpa_partial = create_attention(score_mod, block_mask=block_mask)
+        block_mask1 = create_block_mask(noop_mask, 1, 1, S, S)
+        sdpa_partial1 = create_attention(score_mod, block_mask=block_mask1)
         # The first eager batch, shape (B, H, S, D)
         q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
         k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
         v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
-        golden_out1 = sdpa_partial(
+        golden_out1 = sdpa_partial1(
             q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
         )
-        ref_out1 = sdpa_partial(q1, k1, v1)
+        ref_out1 = sdpa_partial1(q1, k1, v1)

         # The second eager batch, shape (B * 2, H, S / 2, D)
         B = int(B * 2)
         S = int(S / 2)
+        block_mask2 = create_block_mask(noop_mask, 1, 1, S, S)
+        sdpa_partial2 = create_attention(score_mod, block_mask=block_mask2)
         q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
         k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
         v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
-        golden_out2 = sdpa_partial(
+        golden_out2 = sdpa_partial2(
             q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
         )
-        ref_out2 = sdpa_partial(q2, k2, v2)
+        ref_out2 = sdpa_partial2(q2, k2, v2)

         # The third eager batch, shape (B * 4, H, S / 4, D)
         B = int(B * 2)
         S = int(S / 2)
+        block_mask3 = create_block_mask(noop_mask, 1, 1, S, S)
+        sdpa_partial3 = create_attention(score_mod, block_mask=block_mask3)
         q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
         k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
         v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
-        golden_out3 = sdpa_partial(
+        golden_out3 = sdpa_partial3(
             q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64)
         )
-        ref_out3 = sdpa_partial(q3, k3, v3)
+        ref_out3 = sdpa_partial3(q3, k3, v3)

         # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
         # We check dynamo counters["frames"]["ok"] to ensure:
@@ -852,18 +848,17 @@ def run_automatic_dynamic_test(
         fudge_factor = 1.1

         # The first batch.
-        compiled_sdpa = torch.compile(sdpa_partial)
-        compiled_out1 = compiled_sdpa(q1, k1, v1)
+        compiled_out1 = torch.compile(sdpa_partial1)(q1, k1, v1)
         self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor)
         self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

         # The second batch (automatic dynamic).
-        compiled_out2 = compiled_sdpa(q2, k2, v2)
+        compiled_out2 = torch.compile(sdpa_partial2)(q2, k2, v2)
         self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor)
         self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)

         # The third batch (no re-compilation).
-        compiled_out3 = compiled_sdpa(q3, k3, v3)
+        compiled_out3 = torch.compile(sdpa_partial3)(q3, k3, v3)
         self._check_equal(golden_out3, ref_out3, compiled_out3, fudge_factor)
         self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)

@@ -912,11 +907,6 @@ def causal_mask(b, h, q, kv):
     def test_builtin_score_mods_dynamic(
         self, dtype: torch.dtype, score_mask_mod: Tuple[Callable, Callable]
     ):
-        if score_mask_mod[0].__name__ == "_alibi_bias":
-            # TODO
-            self.skipTest(
-                "Alibi bias broken with dynamic shapes since we don't support capturing dynamic shapes"
-            )
         self.run_dynamic_test(score_mask_mod, dtype)

     @supported_platform
@@ -2203,7 +2193,7 @@ def test_differentiable_logsumexp_compiled(self):

     # Use weird mask to test reusing block_mask does work well.
     @supported_platform
-    def test_block_mask_reuse_with_weird_mask(self):
+    def _test_block_mask_reuse_with_weird_mask(self):
         def mask(b, h, q, kv):
             return (kv < 256) | (kv >= 2048)

@@ -3231,12 +3221,12 @@ def causal_mask(b, h, q_idx, kv_idx):
             norm_graph,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_value_: "f64[2, 2, 128, 4]", L_block_mask_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_full_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_full_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_q_num_blocks: "i32[1, 1, 1]", L_block_mask_q_indices: "i32[1, 1, 1, 1]", L_block_mask_full_q_num_blocks: "i32[1, 1, 1]", L_block_mask_full_q_indices: "i32[1, 1, 1, 1]"):
+    def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_value_: "f64[2, 2, 128, 4]", L_block_mask_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_full_kv_num_blocks: "i32[1, 1, 1]", L_block_mask_full_kv_indices: "i32[1, 1, 1, 1]", L_block_mask_q_num_blocks: "i32[1, 1, 1]", L_block_mask_q_indices: "i32[1, 1, 1, 1]", L_block_mask_full_q_num_blocks: "i32[1, 1, 1]", L_block_mask_full_q_indices: "i32[1, 1, 1, 1]"):
         l_query_ = L_query_
         l_key_ = L_key_
         l_value_ = L_value_
-        l_block_mask_kv_num_blocks = L_block_mask_kv_num_blocks
-        l_block_mask_kv_indices = L_block_mask_kv_indices
+        l_block_mask_kv_indices = L_block_mask_kv_indices
+        l_block_mask_kv_num_blocks = L_block_mask_kv_num_blocks
         l_block_mask_full_kv_num_blocks = L_block_mask_full_kv_num_blocks
         l_block_mask_full_kv_indices = L_block_mask_full_kv_indices
         l_block_mask_q_num_blocks = L_block_mask_q_num_blocks
@@ -3246,7 +3236,7 @@ def forward(self, L_query_: "f64[2, 2, 128, 4]", L_key_: "f64[2, 2, 128, 4]", L_
         score_mod_0 = self.score_mod_0
         mask_fn_0 = self.mask_fn_0

-        flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
+        flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
         out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None
         return (out,)
@@ -3287,7 +3277,7 @@ def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]"
         fw_graph0 = self.fw_graph0
         joint_graph0 = self.joint_graph0
         mask_graph0 = self.mask_graph0
-        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
+        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
         getitem_4: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
         getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
         getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
@@ -3643,6 +3633,7 @@ def test_init_mismatched_full_kv(self):
                 full_q_indices=None,
                 BLOCK_SIZE=(64, 64),
                 mask_mod=noop_mask,
+                seq_lengths=(1, 1),
             )

     @supported_platform
@@ -3662,6 +3653,7 @@ def test_init_mismatched_full_q(self):
                 full_q_indices=None,  # Mismatched, should raise error
                 BLOCK_SIZE=(64, 64),
                 mask_mod=noop_mask,
+                seq_lengths=(1, 1),
             )

     @supported_platform
@@ -3787,6 +3779,35 @@ def doc_mask_mod(b, h, q_idx, kv_idx):
             block_mask = create_block_mask(doc_mask_mod, None, None, 1024 + i, 1024 + i)
             torch.compile(flex_attention)(q, k, v, block_mask=block_mask)

+    @common_utils.parametrize("compile", [False, True])
+    @supported_platform
+    def test_block_mask_vs_sequence_lengths(self, compile):
+        if compile:
+            flex_attention_call = torch.compile(flex_attention)
+        else:
+            flex_attention_call = flex_attention
+
+        def mask_mod(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        def create_inputs(S):
+            q, k, v = (
+                torch.randn(
+                    1, 8, S, 64, dtype=torch.float16, requires_grad=True, device="cuda"
+                )
+                for _ in range(3)
+            )
+            return q, k, v
+
+        block_mask = create_block_mask(mask_mod, None, None, 1024, 1024)
+        flex_attention_call(*create_inputs(1024), block_mask=block_mask)
+        with self.assertRaisesRegex(ValueError, "block_mask was created for"):
+            flex_attention_call(*create_inputs(2048), block_mask=block_mask)
+
+        block_mask = create_block_mask(mask_mod, None, None, 1023, 1023)
+        with self.assertRaisesRegex(ValueError, "block_mask was created for"):
+            flex_attention_call(*create_inputs(1024), block_mask=block_mask)
+

 class TestPagedAttention(InductorTestCase):
     def _check_equal(
diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
index 4ae4cb34feb55..1e8c0ada855f0 100644
--- a/test/inductor/test_flex_decoding.py
+++ b/test/inductor/test_flex_decoding.py
@@ -579,9 +579,9 @@ def run_test_with_call_paged_attention(
         ref_out = golden_call(q_ref, k_ref, v_ref)

         if mask_mod is not None:
-            block_mask = create_block_mask(mask_mod, Q_B, 1, 1, S)
+            block_mask = create_block_mask(mask_mod, Q_B, 1, Q_S, KV_S)
         else:
-            block_mask = create_block_mask(noop_mask, Q_B, 1, 1, S)
+            block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S)

         compiled_out, _ = self.run_paged_attention(
             score_mod, q, k, v, dtype, block_mask
@@ -682,7 +682,7 @@ def test_builtin_score_mods_different_block_size(
         score_mod: Callable,
         BLOCK_SIZE: Union[int, Tuple[int, int]],
     ):
-        block_mask = create_block_mask(noop_mask, B, 1, S, S, BLOCK_SIZE=BLOCK_SIZE)
+        block_mask = create_block_mask(noop_mask, B, 1, 1, S, BLOCK_SIZE=BLOCK_SIZE)
         self.run_test(score_mod, dtype, block_mask=block_mask)

 def input_strides_1(B, H, S, D):
@@ -1098,7 +1098,7 @@ def scoremod_1(qk, b, h, q, kv):
         def scoremod_2(qk, b, h, q, kv):
             return torch.where(q >= kv, qk, -float("inf"))

-        block_mask = create_block_mask(noop_mask, 1, 1, 1, S)
+        block_mask = create_block_mask(noop_mask, 1, 1, 4, 1024)

         def f(q, k1, k2, v1, v2):
             q2 = flex_attention(q, k1, v1, score_mod=scoremod_1, block_mask=block_mask)
@@ -1167,7 +1167,7 @@ def scoremod_1(qk, b, h, q, kv):
         def scoremod_2(qk, b, h, q, kv):
             return torch.where(q >= kv, qk, -float("inf"))

-        block_mask = create_block_mask(noop_mask, 1, 1, 1, S)
+        block_mask = create_block_mask(noop_mask, 1, 1, 4, 1024)

         attention1 = functools.partial(
             flex_attention, score_mod=scoremod_1, block_mask=block_mask
@@ -1567,8 +1567,8 @@ def mask_mod(b, h, q_idx, kv_idx):
             mask_mod=mask_mod,
             B=2,
             H=None,
-            Q_LEN=128,
-            KV_LEN=256,
+            Q_LEN=2,
+            KV_LEN=2,
             device="cuda",
         )

diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index 045d7c98ae20f..e15e122fddb99 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -614,7 +614,7 @@ def forward(
                 value,
                 out,
                 logsumexp,
-                *block_mask[:10],
+                *block_mask[:-1],
                 *score_mod_other_buffers,
                 *mask_mod_other_buffers,
             ),
@@ -630,6 +630,8 @@ def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Option
             value,
             out,
             logsumexp,
+            query_lengths,
+            kv_lengths,
             kv_num_blocks,
             kv_indices,
             full_kv_num_blocks,
@@ -672,6 +674,8 @@ def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Option
                 fw_graph,
                 joint_graph,
                 (
+                    query_lengths,
+                    kv_lengths,
                     kv_num_blocks,
                     kv_indices,
                     full_kv_num_blocks,
@@ -708,7 +712,8 @@ def flex_attention_autograd(

     with TransformGetItemToIndex():
         input_requires_grad = any(
-            t.requires_grad for t in (query, key, value, *score_mod_other_buffers)
+            isinstance(t, torch.Tensor) and t.requires_grad
+            for t in (query, key, value, *score_mod_other_buffers)
        )
         if torch.is_grad_enabled() and input_requires_grad:
             example_vals = (
@@ -1130,7 +1135,9 @@ def flex_attention_backward_fake_tensor_mode(
         grad_value = torch.empty_like(value)
         grad_score_mod_captured = tuple(
             [
-                torch.empty_like(buffer) if buffer.requires_grad else None
+                torch.empty_like(buffer)
+                if isinstance(buffer, torch.Tensor) and buffer.requires_grad
+                else None
                 for buffer in score_mod_other_buffers
             ]
         )
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index e38c01fc27f5f..c05df7cd6b0f9 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -810,6 +810,8 @@ def flex_attention(
     mask_mod_other_buffers,
 ):
     (
+        _,  # q_length
+        _,  # kv_length
         kv_num_blocks,
         kv_indices,
         full_kv_num_blocks,
@@ -968,12 +970,6 @@ def flex_attention(
     # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
     SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
     SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
-    assert V.graph.sizevars.evaluate_expr(
-        sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
-    ), "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask."
-    assert V.graph.sizevars.evaluate_expr(
-        sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE))
-    ), "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask."

     # Note, we don't need to pass in the captured buffers explicitly
     # because they're implicitly added by the score_mod function
@@ -1509,7 +1505,7 @@ def bwd_dq_block_mn(
         ) | indent_except_first(2) }}

     if CHECK_BLOCK_BOUNDARY:
-        mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
+        mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
     # apply mask for partial masked block
     post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1541,7 +1537,7 @@ def bwd_dq_block_mn(
     if not IS_FULL_BLOCKS:
         if CHECK_BLOCK_BOUNDARY:
-            mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, float("-inf"))
+            mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
         # (grads) apply mask for partially unmasked block
         ds = tl.where(mask_mod_output, ds, 0.0)
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1691,7 +1687,7 @@ def bwd_dkdv_block_mn(
             n="n",
         ) | indent_except_first(2) }}
     if CHECK_BLOCK_BOUNDARY:
-        mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
+        mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
     # (grads) apply mask for fully masked block
     post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1749,7 +1745,7 @@ def bwd_dkdv_block_mn(
     dsT = grad_scores
     if not IS_FULL_BLOCKS:
         if CHECK_BLOCK_BOUNDARY:
-            mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, float("-inf"))
+            mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
         # (grads) apply mask for partially unmasked block
         dsT = tl.where(mask_mod_output, dsT, 0.0)
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1860,6 +1856,8 @@ def flex_attention_backward(*args, **kwargs):
         mask_mod_other_buffers,
     ) = args
     (
+        _,  # q_length
+        _,  # kv_length
         kv_num_blocks,
         kv_indices,
         full_kv_num_blocks,
@@ -2036,6 +2034,9 @@ def flex_attention_backward(*args, **kwargs):
                 or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
             ):
                 continue
+            if num_warps == 8:
+                # Working around https://github.com/pytorch/pytorch/issues/141603
+                continue

             # Performance tuning
             cur_kernel_options = original_kernel_options.copy()
diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py
index c4608083253ac..c1d99e2593699 100644
--- a/torch/_inductor/kernel/flex_decoding.py
+++ b/torch/_inductor/kernel/flex_decoding.py
@@ -332,6 +332,8 @@ def create_flex_decoding_kernel(*args, **kwargs):
         mask_mod_other_buffers,
     ) = args
     (
+        _,  # q_length
+        _,  # kv_length
         kv_num_blocks,
         kv_indices,
         full_kv_num_blocks,  # full_kv_num_blocks,
diff --git a/torch/nn/attention/experimental/_paged_attention.py b/torch/nn/attention/experimental/_paged_attention.py
index 2c4c5f302dfee..a0cd5c1893b61 100644
--- a/torch/nn/attention/experimental/_paged_attention.py
+++ b/torch/nn/attention/experimental/_paged_attention.py
@@ -264,6 +264,7 @@ def convert_logical_block_mask(

         new_mask_mod = self.get_mask_mod(block_mask.mask_mod)

+        seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)
         return BlockMask.from_kv_blocks(
             new_kv_num_blocks,
             new_kv_indices,
@@ -271,6 +272,7 @@ def convert_logical_block_mask(
             new_full_kv_indices,
             block_mask.BLOCK_SIZE,
             new_mask_mod,
+            seq_lengths=seq_lengths,
         )

     def get_mask_mod(
diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py
index 00774fac7a760..8aef1da0e50c5 100644
--- a/torch/nn/attention/flex_attention.py
+++ b/torch/nn/attention/flex_attention.py
@@ -262,6 +262,7 @@ class BlockMask:
         the backwards pass. These are autogenerated from 2.
     """

+    seq_lengths: Tuple[int, int]
     kv_num_blocks: Tensor
     kv_indices: Tensor
     full_kv_num_blocks: Optional[Tensor]
@@ -275,6 +276,7 @@ class BlockMask:

     def __init__(
         self,
+        seq_lengths: Tuple[int, int],
         kv_num_blocks: Tensor,
         kv_indices: Tensor,
         full_kv_num_blocks: Optional[Tensor],
@@ -299,6 +301,7 @@ def __init__(
             full_q_indices is None
         ), "full_q_num_blocks and full_q_indices must be both provided or omitted"

+        self.seq_lengths = seq_lengths
         self.kv_num_blocks = kv_num_blocks
         self.kv_indices = kv_indices
         self.full_kv_num_blocks = full_kv_num_blocks
@@ -319,6 +322,7 @@ def from_kv_blocks(
         full_kv_indices: Optional[Tensor] = None,
         BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
         mask_mod: Optional[_mask_mod_signature] = None,
+        seq_lengths: Optional[Tuple[int, int]] = None,
     ):
         """
         Creates a BlockMask instance from key-value block information.
@@ -359,8 +363,13 @@ def from_kv_blocks(
             BLOCK_SIZE = (BLOCK_SIZE, BLOCK_SIZE)

         mask_mod = mask_mod if mask_mod is not None else noop_mask
+        if seq_lengths is None:
+            q_length = kv_indices.shape[-2] * BLOCK_SIZE[0]
+            kv_length = q_indices.shape[-2] * BLOCK_SIZE[1]
+            seq_lengths = (q_length, kv_length)

         return cls(
+            seq_lengths=seq_lengths,
             kv_num_blocks=kv_num_blocks,
             kv_indices=kv_indices,
             full_kv_num_blocks=full_kv_num_blocks,
@@ -380,11 +389,15 @@ def as_tuple(self, flatten: bool = True):
         Args:
             flatten (bool): If True, it will flatten the tuple of (KV_BLOCK_SIZE, Q_BLOCK_SIZE)
         """
-        block_size = (
-            (self.BLOCK_SIZE[0], self.BLOCK_SIZE[1]) if flatten else (self.BLOCK_SIZE,)
-        )
+        if flatten:
+            block_size = (self.BLOCK_SIZE[0], self.BLOCK_SIZE[1])  # type: ignore[assignment]
+            seq_lengths = (self.seq_lengths[0], self.seq_lengths[1])  # type: ignore[assignment]
+        else:
+            block_size = (self.BLOCK_SIZE,)  # type: ignore[assignment]
+            seq_lengths = (self.seq_lengths,)  # type: ignore[assignment]

         return (
+            *seq_lengths,
             self.kv_num_blocks,
             self.kv_indices,
             self.full_kv_num_blocks,
@@ -397,6 +410,11 @@ def as_tuple(self, flatten: bool = True):
             self.mask_mod,
         )

+    @property
+    def shape(self):
+        *batch_dims, _, _ = self.kv_indices.shape
+        return tuple(batch_dims) + self.seq_lengths
+
     def __str__(self):
         s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n"
         mask_str = self.to_string().strip()
@@ -457,6 +475,7 @@ def causal_mask(b, h, q_idx, kv_idx):
             new_full_kv_indices,
             BLOCK_SIZE=self.BLOCK_SIZE,
             mask_mod=None,
+            seq_lengths=self.seq_lengths,
         )

     def __repr__(self):
@@ -509,14 +528,6 @@ def _adjust(self, new_q_len: int, new_kv_len: int):
             self.mask_mod,
         )

-    @property
-    def shape(self):
-        """Returns the shape of the mask."""
-        *batch_dims, q_length, _ = self.kv_indices.shape
-        q_length = self.kv_indices.shape[-2] * self.BLOCK_SIZE[0]
-        kv_length = self.kv_indices.shape[-1] * self.BLOCK_SIZE[1]
-        return tuple(batch_dims + [q_length, kv_length])
-
     def numel(self):
         """Returns the number of elements (not accounting for sparsity) in the mask."""
         shape = self.shape
@@ -739,6 +750,7 @@ def _convert_block_mask_to_mask(
 def _create_sparse_block_from_block_mask(
     block_mask: Tuple[Tensor, Optional[Tensor]],
     mask_mod: Optional[Callable],
+    seq_lengths: Tuple[int, int],
     Q_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
     KV_BLOCK_SIZE: int = _DEFAULT_SPARSE_BLOCK_SIZE,
 ) -> BlockMask:
@@ -757,6 +769,7 @@ def _create_sparse_block_from_block_mask(
         full_bm[1],
         BLOCK_SIZE=(Q_BLOCK_SIZE, KV_BLOCK_SIZE),
         mask_mod=mask_mod,
+        seq_lengths=seq_lengths,
     )

@@ -878,7 +891,11 @@ def causal_mask(b, h, q_idx, kv_idx):
         separate_full_blocks=True,
     )
     block_mask = _create_sparse_block_from_block_mask(
-        (partial_block_mask, full_block_mask), mask_mod, Q_BLOCK_SIZE, KV_BLOCK_SIZE
+        (partial_block_mask, full_block_mask),
+        mask_mod,
+        (Q_LEN, KV_LEN),
+        Q_BLOCK_SIZE,
+        KV_BLOCK_SIZE,
     )
     return block_mask

@@ -894,6 +911,7 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
         kv_num_blocks=torch.ones([1, 1, 1], dtype=torch.int32, device=device),
         kv_indices=torch.zeros([1, 1, 1, 1], dtype=torch.int32, device=device),
         BLOCK_SIZE=_LARGE_SPARSE_BLOCK_SIZE,
+        seq_lengths=(1, 1),
     )

@@ -1237,29 +1255,31 @@ def score_mod(

     if block_mask is None:
         block_mask = _create_empty_block_mask(query, key)
-    elif (
-        not query.is_nested
-        and (query.requires_grad or key.requires_grad or value.requires_grad)
-        and (
-            query.size(-2)
-            < block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]
-            or key.size(-2) < block_mask.kv_indices.size(-1) * block_mask.BLOCK_SIZE[1]
-        )
-    ):
-        new_q_len = _round_up_to_multiple(query.size(-2), block_mask.BLOCK_SIZE[0])
-        new_kv_len = _round_up_to_multiple(key.size(-2), block_mask.BLOCK_SIZE[1])
-        block_mask = block_mask._adjust(new_q_len, new_kv_len)
-    elif query.is_nested and (
-        block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]
-        != _round_up_to_multiple(
-            query._values.size(query._ragged_idx - 1), block_mask.BLOCK_SIZE[0]  # type: ignore[attr-defined]
-        )
+
+    if (
+        block_mask.BLOCK_SIZE[0] == _LARGE_SPARSE_BLOCK_SIZE
+        and block_mask.BLOCK_SIZE[1] == _LARGE_SPARSE_BLOCK_SIZE
     ):
-        # TODO: Maybe we want to auto-adjust for this case as well?
-        raise RuntimeError(
-            f"block_mask of shape {block_mask.shape} is not compatible with nested tensor input "
-            f"with total sequence length of {query._values.size(query._ragged_idx - 1)}"  # type: ignore[attr-defined]
-        )
+        # This corresponds to the case where we essentially have a "no-op" block mask.
+        pass
+    else:
+        block_mask_q_len = block_mask.shape[-2]
+        block_mask_kv_len = block_mask.shape[-1]
+        if query.size(-2) > block_mask_q_len or key.size(-2) > block_mask_kv_len:
+            raise ValueError(
+                f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. "
+                "As the block mask was created for a smaller length than you're using it for, you likely need to create a new block mask."
+            )
+        elif (
+            query.size(-2) < block_mask_q_len and key.size(-2) <= block_mask_kv_len
+        ) or (query.size(-2) <= block_mask_q_len and key.size(-2) < block_mask_kv_len):
+            raise ValueError(
+                f"block_mask was created for block_mask.shape={block_mask.shape} but got q_len={query.size(-2)} and kv_len={key.size(-2)}. "
+                "As the block mask was created for a larger length than you're using it for, you can either 1. create a new block mask with the correct length, or 2. 'adjust' the existing block mask to the correct length by calling block_mask._adjust(q_len, kv_len). This essentially 'crops' the block mask to the upper left corner, which does not work for all mask_mods!"
+            )
+        assert query.size(-2) == block_mask_q_len
+        assert key.size(-2) == block_mask_kv_len
+
     if scale is None:
         scale = 1.0 / math.sqrt(query.size(-1))
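
Usage note: a minimal sketch of the validation this patch introduces, mirroring the new test_block_mask_vs_sequence_lengths test above; the causal mask_mod, tensor shapes, dtype, and device below are illustrative choices, not taken from the diff.

# Sketch: after this change, a BlockMask must be created for exactly the
# sequence lengths it is used with, instead of being silently reused/adjusted.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def make_qkv(S):
    # (B, H, S, D) inputs; float16 on CUDA is an illustrative choice.
    return (
        torch.randn(1, 8, S, 64, dtype=torch.float16, device="cuda")
        for _ in range(3)
    )

block_mask = create_block_mask(causal, None, None, 1024, 1024)
out = flex_attention(*make_qkv(1024), block_mask=block_mask)  # OK: lengths match

try:
    # Mask built for 1024 tokens, inputs have 2048: this now raises a ValueError
    # ("block_mask was created for ...") rather than reusing the smaller mask.
    flex_attention(*make_qkv(2048), block_mask=block_mask)
except ValueError as e:
    print(e)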