Skip to content

Commit

Permalink
#0: fix split qkv
Browse files Browse the repository at this point in the history
  • Loading branch information
yugaoTT committed Nov 1, 2024
1 parent 0fa3566 commit 2b7fddf
Showing 1 changed file with 0 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,10 @@ def test_split_query_key_value_and_split_heads_with_program_cache(device, dtype,

tt_q = ttnn.sharded_to_interleaved(q, out_mem_config)
tt_q = tt_q.cpu().to_torch().float()
tt_q = untilize(tt_q)
tt_k = ttnn.sharded_to_interleaved(k, out_mem_config)
tt_k = tt_k.cpu().to_torch().float()
tt_k = untilize(tt_k)
tt_v = ttnn.sharded_to_interleaved(v, out_mem_config)
tt_v = tt_v.cpu().to_torch().float()
tt_v = untilize(tt_v)

fused_qkv_heads = torch.split(in0, input_shape[-1] // grid_size[1], dim=-1)
ref_q_list = []
Expand Down

0 comments on commit 2b7fddf

Please sign in to comment.