Skip to content

Commit

Permalink
update missed methods
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanNijjar committed Jan 2, 2025
1 parent f080a5f commit fa868ef
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions ttnn/ttnn/distributed/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,12 @@ def __init__(self, mesh_device: MeshDevice, dim: int):
self.concat_dim = dim
self.mesh_device = mesh_device

def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor":
def compose(self, tensor: ttnn.Tensor, sub_device_ids: List[ttnn.SubDeviceId] = []) -> "torch.Tensor":
import torch

device_shards_converted_to_torch = [
ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor)
ttnn.to_torch(tt_input_tensor, mesh_composer=None, sub_device_ids=sub_device_ids)
for tt_input_tensor in ttnn.get_device_tensors(tensor)
]
return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim)

Expand All @@ -467,9 +468,10 @@ class ListMeshToTensor(MeshToTensor):
def __init__(self, mesh_device: MeshDevice):
self.mesh_device = mesh_device

def compose(self, tensor: ttnn.Tensor) -> List["torch.Tensor"]:
def compose(self, tensor: ttnn.Tensor, sub_device_ids: List[ttnn.SubDeviceId] = []) -> List["torch.Tensor"]:
return [
ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor)
ttnn.to_torch(tt_input_tensor, mesh_composer=None, sub_device_ids=sub_device_ids)
for tt_input_tensor in ttnn.get_device_tensors(tensor)
]


Expand Down

0 comments on commit fa868ef

Please sign in to comment.