Use send/recv to implement all2all #1351

Open · wants to merge 7 commits into main
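In short, this PR replaces the collective ccl::alltoall / ccl::alltoallv calls in ProcessGroupXCCL with per-peer ccl::send / ccl::recv pairs batched inside a ccl::group_start() / ccl::group_end() region: each rank sends its r-th block to peer r and receives peer r's block into its r-th output slot. Below is a minimal sketch of the equal-split pattern, assuming an already-initialized oneCCL communicator and a SYCL-queue-backed ccl::stream; the function name and parameters are illustrative and not part of the diff.

#include <oneapi/ccl.hpp>

// Sketch only: every rank exchanges one fixed-size block with every peer.
// Peer r's block starts at byte offset r * block_bytes in both buffers.
void alltoall_via_send_recv(
    void* send_buf,
    void* recv_buf,
    size_t count,        // elements per peer block
    size_t block_bytes,  // bytes per peer block
    ccl::datatype dtype,
    int comm_size,
    ccl::communicator& comm,
    ccl::stream& stream) {
  if (count == 0) {
    return;
  }
  ccl::group_start();
  for (int r = 0; r < comm_size; ++r) {
    ccl::send((char*)send_buf + r * block_bytes, count, dtype, r, comm, stream);
    ccl::recv((char*)recv_buf + r * block_bytes, count, dtype, r, comm, stream);
  }
  ccl::group_end();
}

Batching the pairs between group_start() and group_end() lets the runtime progress all peers together rather than serializing blocking point-to-point calls; the uneven-split and tensor-list variants in the diff below follow the same shape, only with per-peer counts and offsets.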
204 changes: 90 additions & 114 deletions src/xccl/ProcessGroupXCCL.cpp
@@ -31,44 +31,6 @@ const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
{at::kFloat8_e5m2fnuz, ccl::datatype::uint8},
};

bool computeLengthsAndCheckAndGetFlat(
const std::vector<at::Tensor>& tensors,
std::vector<size_t>& lengths,
at::Tensor& flatTensor,
int64_t& flatLength) {
int64_t groupSize = tensors.size();
auto firstTensor = tensors[0];
int64_t totalSize = 0;
bool isFlat = true;

auto storage = firstTensor.storage();
int64_t firstStorageOffset = firstTensor.storage_offset();

for (int i = 0; i < groupSize; i++) {
auto& curTensor = tensors[i];
int64_t length = curTensor.numel();
lengths[i] = length;
totalSize += length;

if (isFlat &&
(!storage.is_alias_of(curTensor.storage()) ||
curTensor.storage_offset() !=
firstStorageOffset + totalSize - length)) {
isFlat = false;
}
}

flatLength = totalSize;

if (isFlat) {
flatTensor = firstTensor;
} else {
flatTensor = at::empty({totalSize}, firstTensor.options());
}

return isFlat;
}

bool checkSameSize(const std::vector<at::Tensor>& input_tensors) {
for (const auto& input_tensor : input_tensors) {
if (!input_tensors[0].is_same_size(input_tensor)) {
@@ -1650,13 +1612,32 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
c10::xpu::XPUCachingAllocator::recordStream(
output.storage().data_ptr(), stream);
auto xcclDataType = getXcclDataType(output.scalar_type());
ccl::alltoall(
input.data_ptr(),
output.data_ptr(),
(size_t)output.numel() / comm.size(),
xcclDataType,
comm,
ccl::create_stream(stream.queue()));
auto xccl_stream = ccl::create_stream(stream.queue());

size_t count = input.numel() / size_;
size_t rankdiff = input.nbytes() / size_;

ccl::group_start();
for (const auto r : c10::irange(size_)) {
if (count != 0) {
ccl::send(
((char*)input.data_ptr()) + r * rankdiff,
count,
xcclDataType,
r,
comm,
xccl_stream);
ccl::recv(
((char*)output.data_ptr()) + r * rankdiff,
count,
xcclDataType,
r,
comm,
xccl_stream);
}
}
ccl::group_end();

return;
},
OpType::ALLTOALL_BASE,
@@ -1690,33 +1671,47 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
at::Tensor& output,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
std::vector<size_t> sendCounts(size_);
std::vector<size_t> recvCounts(size_);
bool inputSplitsEqual = inputSplitSizes.size() == 0;
bool outputSplitsEqual = outputSplitSizes.size() == 0;

size_t inLen = input.numel();
size_t outLen = output.numel();
if (inLen)
inLen /= (inputSplitsEqual ? size_ : input.size(0));
if (outLen)
outLen /= (outputSplitsEqual ? size_ : output.size(0));

for (int i = 0; i < size_; i++) {
sendCounts[i] =
(inputSplitsEqual ? inLen : inputSplitSizes[i] * inLen);
recvCounts[i] =
(outputSplitsEqual ? outLen : outputSplitSizes[i] * outLen);
std::vector<size_t> send_lengths(size_);
std::vector<size_t> recv_lengths(size_);
std::vector<size_t> send_offsets(size_);
std::vector<size_t> recv_offsets(size_);
c10d::computeLengthsAndOffsets(
inputSplitSizes, input, &send_lengths, &send_offsets);
c10d::computeLengthsAndOffsets(
outputSplitSizes, output, &recv_lengths, &recv_offsets);

size_t size = input.element_size();
auto xcclDataType = getXcclDataType(input.scalar_type());
c10::xpu::XPUCachingAllocator::recordStream(
output.storage().data_ptr(), stream);

auto send_offsets_data = send_offsets.data();
auto recv_offsets_data = recv_offsets.data();
auto xccl_stream = ccl::create_stream(stream.queue());

ccl::group_start();
for (const auto r : c10::irange(size_)) {
if (send_lengths[r] != 0) {
ccl::send(
((char*)input.data_ptr()) + send_offsets_data[r] * size,
send_lengths[r],
xcclDataType,
r,
comm,
xccl_stream);
}
if (recv_lengths[r] != 0) {
ccl::recv(
((char*)output.data_ptr()) + recv_offsets_data[r] * size,
recv_lengths[r],
xcclDataType,
r,
comm,
xccl_stream);
}
}
auto xcclDataType = getXcclDataType(output.scalar_type());
ccl::alltoallv(
input.data_ptr(),
sendCounts,
output.data_ptr(),
recvCounts,
xcclDataType,
comm,
ccl::create_stream(stream.queue()));
ccl::group_end();

return;
},
OpType::ALLTOALL_BASE,
@@ -1764,52 +1759,33 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
at::Tensor& /* unused */,
xcclComm_t& comm,
at::xpu::XPUStream& stream) {
c10::OptionalStreamGuard stream_guard(stream.unwrap());
at::Tensor flatInput;
at::Tensor flatOutput;

std::vector<size_t> sendCounts(size_);
std::vector<size_t> recvCounts(size_);

int64_t flatSendCount;
int64_t flatRecvCount;

bool isInputFlat = computeLengthsAndCheckAndGetFlat(
inputTensors, sendCounts, flatInput, flatSendCount);
bool isOutputFlat = computeLengthsAndCheckAndGetFlat(
outputTensors, recvCounts, flatOutput, flatRecvCount);
if (!isInputFlat) {
auto flatInputSplits = flatInput.split_with_sizes(
c10::IntArrayRef((int64_t*)sendCounts.data(), sendCounts.size()),
0);

for (int i = 0; i < size_; i++) {
flatInputSplits[i].copy_(inputTensors[i].view({-1}));
auto xccl_stream = ccl::create_stream(stream.queue());
ccl::group_start();
for (const int r :
c10::irange(static_cast<int>(outputTensors.size()))) {
at::Tensor& input = inputTensors[r];
at::Tensor& output = outputTensors[r];
if (input.numel() != 0) {
ccl::send(
input.data_ptr(),
input.numel(),
getXcclDataType(input.scalar_type()),
r,
comm,
xccl_stream);
}
}

auto xcclDataType = getXcclDataType(flatOutput.scalar_type());
ccl::event ret_evt;
ret_evt = ccl::alltoallv(
flatInput.data_ptr(),
sendCounts,
flatOutput.data_ptr(),
recvCounts,
xcclDataType,
comm,
ccl::create_stream(stream.queue()));

if (!isOutputFlat) {
ret_evt.wait();
auto flatOutputSplits = flatOutput.split_with_sizes(
c10::IntArrayRef((int64_t*)recvCounts.data(), recvCounts.size()),
0);

for (int i = 0; i < size_; i++) {
outputTensors[i].view({-1}).copy_(flatOutputSplits[i]);
if (output.numel() != 0) {
ccl::recv(
output.data_ptr(),
output.numel(),
getXcclDataType(output.scalar_type()),
r,
comm,
xccl_stream);
}
}
stream.synchronize();
ccl::group_end();

return;
},
OpType::ALLTOALL,
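For the uneven-split alltoall_base path above, the per-peer element lengths and offsets come from c10d::computeLengthsAndOffsets, and the byte address passed to ccl::send / ccl::recv is data_ptr plus offsets[r] times element_size. Here is a rough stand-alone illustration of that bookkeeping (not the PyTorch helper itself, just a sketch of what it produces; split sizes are assumed to count rows along dim 0, with an even split when the list is empty):

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative sketch: derive per-peer element lengths and element offsets
// from dim-0 split sizes, falling back to an even split when none are given.
void compute_lengths_and_offsets(
    const std::vector<int64_t>& split_sizes,
    size_t total_elems,  // tensor.numel()
    size_t row_elems,    // elements per dim-0 row, i.e. numel / size(0)
    int world_size,
    std::vector<size_t>& lengths,
    std::vector<size_t>& offsets) {
  lengths.resize(world_size);
  offsets.resize(world_size);
  size_t offset = 0;
  for (int r = 0; r < world_size; ++r) {
    lengths[r] = split_sizes.empty()
        ? total_elems / world_size                          // even split
        : static_cast<size_t>(split_sizes[r]) * row_elems;  // rows -> elements
    offsets[r] = offset;
    offset += lengths[r];
  }
}

With lengths and offsets in hand, the send/recv loop matches the equal-split sketch, except each peer r gets its own count and byte offset.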
68 changes: 68 additions & 0 deletions test/xpu/distributed/test_c10d_ops_xccl.py
@@ -807,6 +807,74 @@ def test_send_recv_object_list(self):
dist.recv_object_list(object_list, 0, device=device)
self.assertEqual(object_list[0], 99)

@requires_xccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs")
def test_all_to_all_single(self):
device = self.rank_to_GPU[self.rank][0]
row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2
x = torch.ones(int(row), 5, device=device) * (self.rank + 1)
x.requires_grad = True
y = torch.empty_like(x)
split_sizes = [(i + 1) * (self.rank + 1) for i in range(self.world_size)]
y = torch.distributed.nn.all_to_all_single(
y, x, output_split_sizes=split_sizes, input_split_sizes=split_sizes
)
expected = []
for idx, tensor in enumerate(torch.split(x, split_sizes)):
expected.append(torch.full_like(tensor, (idx + 1)))
expected = torch.cat(expected)
self.assertEqual(y, expected)
z = y.sin().sum()
z.backward()
x_s = ((self.rank + 1) * torch.ones(int(row), 5, device=device)).cos()
self.assertEqual(x.grad, x_s)

@requires_xccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs")
def test_all_to_all_single_unequal_split(self):
device = self.rank_to_GPU[self.rank][0]
in_splits = [i + 1 for i in range(self.world_size)]
out_splits = [self.rank + 1 for _ in range(self.world_size)]
in_tensor = torch.ones([sum(in_splits), self.world_size]) * self.rank
out_tensor = torch.ones([(self.rank + 1) * self.world_size, self.world_size])
expected_tensor = torch.cat(
[
torch.ones([self.rank + 1, self.world_size]) * i
for i in range(self.world_size)
]
)

in_tensor = in_tensor.to(device)
expected_tensor = expected_tensor.to(device)
out_tensor = out_tensor.to(device)
dist.all_to_all_single(out_tensor, in_tensor, out_splits, in_splits)
self.assertEqual(out_tensor, expected_tensor)

@requires_xccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs")
def test_all_to_all(self, dtype=torch.float):
device = self.rank_to_GPU[self.rank][0]
in_splits = [i + 1 for i in range(self.world_size)]
in_tensors = [
torch.ones([in_splits[i], self.world_size], dtype=dtype) * self.rank
for i in range(self.world_size)
]
out_tensors = [
torch.ones([(self.rank + 1), self.world_size], dtype=dtype)
for _ in range(self.world_size)
]
expected_tensors = [
torch.ones([self.rank + 1, self.world_size], dtype=dtype) * i
for i in range(self.world_size)
]

in_tensors = [t.to(device) for t in in_tensors]
expected_tensors = [t.to(device) for t in expected_tensors]
out_tensors = [t.to(device) for t in out_tensors]
dist.all_to_all(out_tensors, in_tensors)
for t1, t2 in zip(out_tensors, expected_tensors):
self.assertEqual(t1, t2)


if __name__ == "__main__":
rank = int(os.getenv("RANK", -1))