diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py index 2b2c0ff7f33..dc03d7bca85 100644 --- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py @@ -271,7 +271,7 @@ def paged_flash_attention_kernel( num_kv_pages_per_compute_block: int, mask_value: float, query_len: int, - attn_logits_soft_cap: float | None = None, + attn_logits_soft_cap: float | None, ): """Pallas kernel for paged attention.""" b, kv_head_idx, q_blk_idx, kv_blk_idx = (