Commit

Fix mark_sharding logic
yaochengji committed Jan 15, 2025
1 parent 1c89675 commit a28b637
Showing 2 changed files with 20 additions and 2 deletions.
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -36,6 +36,7 @@ python3 "$TEST_CDIR/scan/test_scan_layers.py"
 run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
 python3 "$TEST_CDIR/test_pallas.py" -v
 python3 "$TEST_CDIR/test_pallas_spmd.py"
+XLA_DISABLE_FUNCTIONALIZATION=1 python3 "$TEST_CDIR/test_pallas_spmd.py"
 python3 "$TEST_CDIR/test_tpu_paged_attention_kernel.py"
 python3 "$TEST_CDIR/test_input_output_aliases.py"
 python3 "$TEST_CDIR/test_gmm.py"
21 changes: 19 additions & 2 deletions torch_xla/csrc/tensor_methods.cpp
@@ -607,8 +607,25 @@ void custom_sharding_(
     const XLATensorPtr& input,
     const std::shared_ptr<XLATensor::ShardingSpec>& sharding_spec,
     const CustomSharding::Type& type) {
-  input->SetInPlaceIrValue(torch_xla::MakeNode<CustomSharding>(
-      input->GetIrValue(), input->shape().get(), type));
+  torch::lazy::NodePtr customShardingNode = torch_xla::MakeNode<CustomSharding>(
+      input->GetIrValue(), input->shape().get(), type);
+  XlaNode* xla_node = dynamic_cast<XlaNode*>(customShardingNode.get());
+  // Always call `SetSharding` to ensure the `CustomSharding` op has the correct
+  // sharding, especially if a view is updated afterward. Updating a view can
+  // modify the IR, potentially leading to the sharding being applied to the
+  // updated view instead of the original `CustomSharding` op.
+
+  // For example, consider the following IR:
+  // ```
+  // x0 = custom_sharding(input)
+  // x1 = view_update(x0)
+  // ```
+  // In this case, we want to ensure the sharding is applied to `x0`, not `x1`.
+
+  // While this solution may add a sharding spec to non-CustomSharding ops like
+  // `x1`, the XLA compiler will safely ignore it.
+  xla_node->SetSharding(sharding_spec->sharding, 0);
+  input->SetInPlaceIrValue(customShardingNode);
+  input->SetShardingSpec(*sharding_spec);
 }
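
The ordering issue described in the code comment can be illustrated with a small standalone model. The sketch below is not torch_xla code: `ToyNode`, `ToyTensor`, `AnnotateViaTensorOnly`, and `AnnotateNodeFirst` are hypothetical names invented for illustration, and the model assumes that a tensor-level annotation ends up on whichever IR node the tensor points at when the annotation is resolved. Under that assumption, it shows why pinning the sharding on the `custom_sharding` node itself keeps the annotation on `x0` even after a view update produces `x1`.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Toy stand-in for an IR node (hypothetical; not the real XlaNode).
struct ToyNode {
  std::string op;
  std::shared_ptr<ToyNode> operand;
  std::string sharding;  // empty means "no sharding annotation"
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

// Toy stand-in for a tensor that tracks its current IR node.
struct ToyTensor {
  ToyNodePtr ir;

  // Assumption of this model: a tensor-level annotation lands on
  // whichever node the tensor currently points at.
  void SetShardingOnCurrentIr(const std::string& s) { ir->sharding = s; }

  // Models a view update: the tensor now points at a new node (x1)
  // that wraps the previous one (x0).
  void UpdateView() {
    ir = std::make_shared<ToyNode>(ToyNode{"view_update", ir, ""});
  }
};

// Tensor-level annotation only: after the view update the annotation
// is applied to x1, and x0 is left without a sharding.
ToyNodePtr AnnotateViaTensorOnly(ToyTensor& t, const std::string& s) {
  assert(t.ir != nullptr);
  ToyNodePtr x0 =
      std::make_shared<ToyNode>(ToyNode{"custom_sharding", t.ir, ""});
  t.ir = x0;
  t.UpdateView();
  t.SetShardingOnCurrentIr(s);
  return x0;
}

// Node-level annotation first, mirroring the order used in the commit:
// x0 carries the sharding regardless of later view updates.
ToyNodePtr AnnotateNodeFirst(ToyTensor& t, const std::string& s) {
  assert(t.ir != nullptr);
  ToyNodePtr x0 =
      std::make_shared<ToyNode>(ToyNode{"custom_sharding", t.ir, ""});
  x0->sharding = s;  // pin the annotation on the node itself
  t.ir = x0;
  t.UpdateView();
  t.SetShardingOnCurrentIr(s);  // x1 also gets a spec, as the comment notes
  return x0;
}
```

In this model, `AnnotateViaTensorOnly` returns an `x0` whose `sharding` is empty, while `AnnotateNodeFirst` returns an `x0` that holds the annotation. In the second case `x1` carries a redundant spec as well, which corresponds to the caveat in the commit comment about non-CustomSharding ops.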
