Skip to content

Commit

Permalink
pin update (#8559)
Browse files Browse the repository at this point in the history
Co-authored-by: Siyuan Liu <[email protected]>
  • Loading branch information
lsy323 and Siyuan Liu authored Jan 13, 2025
1 parent 2fdf721 commit 1c89675
Show file tree
Hide file tree
Showing 6 changed files with 5 additions and 7 deletions.
2 changes: 1 addition & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ new_local_repository(
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.

xla_hash = 'd28bfbdc366627c9ac9f57fcaa512ff04de19d6f'
xla_hash = '8d06f3680ad046ea44f8e7159f52c728bb66c069'

http_archive(
name = "xla",
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
base_dir = os.path.dirname(os.path.abspath(__file__))

USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax
_date = '20250106'
_date = '20250113'
_libtpu_version = f'0.0.8'
_jax_version = f'0.4.39'
_jaxlib_version = f'0.4.39'
Expand Down
2 changes: 0 additions & 2 deletions test/cpp/test_aten_xla_tensor_2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,6 @@ TEST_F(AtenXlaTensorTest, TestLinalgVectorNormInDimsKeepDtype) {
}

TEST_F(AtenXlaTensorTest, TestLinalgEigh) {
// TODO: Broken by XLA pin update on 20250106.
GTEST_SKIP();
// Hardcode the test input to avoid numerical instability from randomness,
// which is a problem in eigenvalue decomposition.
auto complex64 = [](float real, float imag) {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/dl_convertor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {

pack->shape = std::vector<int64_t>(pjrt_buffer->dimensions().begin(),
pjrt_buffer->dimensions().end());
xla::Layout xla_layout = xla::GetXlaLayoutUnsafe(pjrt_buffer->layout());
xla::Layout xla_layout = pjrt_buffer->layout()->xla_layout();
pack->strides = StridesForShape(pjrt_buffer->element_type(),
pjrt_buffer->dimensions(), xla_layout);
dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ cc_library(
"@xla//xla:literal_util",
"@xla//xla/client:xla_computation",
"@xla//xla/hlo/ir:hlo",
"@xla//xla/pjrt:pjrt_client",
],
)

Expand Down
3 changes: 1 addition & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ std::unordered_map<int, int> build_index_map(
xla::Shape host_output_shape(xla::PjRtBuffer* buffer) {
xla::Shape shape = xla::ShapeUtil::MakeShape(
buffer->element_type(), buffer->logical_dimensions().value());
*shape.mutable_layout() = xla::GetXlaLayoutUnsafe(buffer->layout());

*shape.mutable_layout() = buffer->layout()->xla_layout();
return xla::ShapeUtil::DeviceShapeToHostShape(shape);
}

Expand Down

0 comments on commit 1c89675

Please sign in to comment.