Change create_tt_tensor_from_py_data to use from_vector #16999

Closed
wants to merge 30 commits
Changes from 4 commits

Commits
35ba041
Change create_tt_tensor_from_py_data to use from_vector
jjiangTT Jan 22, 2025
7ac5be0
removed switch case default, use bfloat16 for bfp_b types, shift bfp_…
jjiangTT Jan 23, 2025
cec8ec8
add encode_tensor_data to create_owned_tensor_from_row_major_data
jjiangTT Jan 23, 2025
4352603
removed graph printng and parameterized tests splitting out bfp_b
jjiangTT Jan 23, 2025
3360602
Remove extraneous comments and double declarations, change data_ptr t…
jjiangTT Jan 23, 2025
daad2e1
clean up create_owned_tensor_from_row_major_data
jjiangTT Jan 23, 2025
b332630
Merge branch 'main' into jjiang/16837-tensor_creation_and_conversion
jjiangTT Jan 23, 2025
8c6802b
additional minor formatting fixes
jjiangTT Jan 23, 2025
7b5d3b0
Merge branch 'jjiang/16837-tensor_creation_and_conversion' of https:/…
jjiangTT Jan 23, 2025
e9f5bf4
Added shape conversion testing, borrow testing for bfp_b types, and s…
jjiangTT Jan 24, 2025
fecf95c
Change create_tt_tensor_from_py_data to use from_vector
jjiangTT Jan 22, 2025
c0873cb
removed switch case default, use bfloat16 for bfp_b types, shift bfp_…
jjiangTT Jan 23, 2025
c6c387c
add encode_tensor_data to create_owned_tensor_from_row_major_data
jjiangTT Jan 23, 2025
8454b66
removed graph printng and parameterized tests splitting out bfp_b
jjiangTT Jan 23, 2025
c837112
Remove extraneous comments and double declarations, change data_ptr t…
jjiangTT Jan 23, 2025
42d8476
clean up create_owned_tensor_from_row_major_data
jjiangTT Jan 23, 2025
74a4412
additional minor formatting fixes
jjiangTT Jan 23, 2025
ba5de4c
Added shape conversion testing, borrow testing for bfp_b types, and s…
jjiangTT Jan 24, 2025
68cf514
Merge branch 'jjiang/16837-tensor_creation_and_conversion' of https:/…
jjiangTT Jan 24, 2025
ea0c02d
Fix layout error on validstorage test
jjiangTT Jan 24, 2025
b1b5f10
move validstorage test into test_convert_python_tensor.py, add shard …
jjiangTT Jan 27, 2025
482b0d2
fix shard bounding
jjiangTT Jan 27, 2025
cfa67a5
Merge branch 'main' into jjiang/16837-tensor_creation_and_conversion
jjiangTT Jan 28, 2025
f785091
remove unnecessary asserts and deprecated create_tensor method
jjiangTT Jan 28, 2025
8d4597b
fix test_convert_python_tensor imports, move convert_python to unit_t…
jjiangTT Jan 28, 2025
17ecda7
Remove todo for type checking logic
jjiangTT Jan 28, 2025
ff70a26
Merge branch 'main' into jjiang/16837-tensor_creation_and_conversion
jjiangTT Jan 28, 2025
cc8f9e0
fix graph tracing errors in test_convert_python, fix extraneous tt_fa…
jjiangTT Jan 28, 2025
c52e47f
fix pytorch type and comparison errors
jjiangTT Jan 29, 2025
d8c6c40
remove extraneous includes
jjiangTT Jan 29, 2025
54 changes: 54 additions & 0 deletions tests/ttnn/unit_tests/test_convert_python_tensor.py
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pathlib
import pytest

import torch

import ttnn


@pytest.mark.parametrize("size", [64])
@pytest.mark.parametrize("mode", [ttnn.graph.RunMode.NO_DISPATCH, ttnn.graph.RunMode.NORMAL])
@pytest.mark.parametrize("dtype", [torch.int32, torch.float, torch.bfloat16])
def test_convert_python_tensor(device, size, mode, dtype):
torch.manual_seed(0)

ttnn.graph.begin_graph_capture(mode)
torch_input_tensor = torch.rand((size,), dtype=dtype) if dtype.is_floating_point else torch.randint(0, 100, (size,), dtype=dtype)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.to_torch(input_tensor, torch_rank=1)
captured_graph = ttnn.graph.end_graph_capture()
calltrace = ttnn.graph.extract_calltrace(captured_graph)

assert torch.allclose(output_tensor, torch_input_tensor)
assert "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor" in calltrace
assert captured_graph[0]["node_type"] == "capture_start"
assert captured_graph[1]["node_type"] == "function_start"
assert captured_graph[1]["params"]["name"] == "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor"
assert captured_graph[-2]["node_type"] == "buffer_deallocate"
assert captured_graph[-1]["node_type"] == "capture_end"


@pytest.mark.parametrize("size", [64])
@pytest.mark.parametrize("mode", [ttnn.graph.RunMode.NO_DISPATCH, ttnn.graph.RunMode.NORMAL])
@pytest.mark.parametrize("dtype", [ttnn.bfloat4_b, ttnn.bfloat8_b])
def test_convert_python_tensor_bfp_b(device, size, mode, dtype):
torch.manual_seed(0)

ttnn.graph.begin_graph_capture(mode)
torch_input_tensor = torch.rand((size,), dtype=torch.float)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device, dtype=dtype)
output_tensor = ttnn.to_torch(input_tensor, torch_rank=1)
captured_graph = ttnn.graph.end_graph_capture()
calltrace = ttnn.graph.extract_calltrace(captured_graph)

# bfloat8_b/bfloat4_b are lossy block-float formats, so compare with loose tolerances
assert torch.allclose(output_tensor, torch_input_tensor, rtol=0.1, atol=0.1)
assert "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor" in calltrace
assert captured_graph[0]["node_type"] == "capture_start"
assert captured_graph[1]["node_type"] == "function_start"
assert captured_graph[1]["params"]["name"] == "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor"
assert captured_graph[-2]["node_type"] == "buffer_deallocate"
assert captured_graph[-1]["node_type"] == "capture_end"
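
For context, a minimal sketch of the round trip these tests exercise, assuming a working ttnn install and an attached device (open_device/close_device are standard ttnn API used here for illustration; they are not part of this PR):

import torch
import ttnn

device = ttnn.open_device(device_id=0)
torch_input = torch.rand((64,), dtype=torch.bfloat16)
# ROW_MAJOR host tensors of torch-native dtypes can borrow the torch buffer;
# TILE_LAYOUT, padding, or bfloat8_b/bfloat4_b force a copy through Tensor::from_vector.
tt_tensor = ttnn.from_torch(torch_input, layout=ttnn.TILE_LAYOUT, device=device)
round_trip = ttnn.to_torch(tt_tensor, torch_rank=1)
assert torch.allclose(round_trip, torch_input)
ttnn.close_device(device)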
127 changes: 56 additions & 71 deletions ttnn/cpp/pybind11/pytensor.cpp
@@ -85,6 +85,45 @@ Tensor create_owned_tensor(T* data_ptr, const ttnn::TensorSpec& tensor_spec) {
return Tensor(std::move(storage), tensor_spec);
}

template <typename T>
Tensor create_typed_tt_tensor_from_py_data(
std::size_t py_data_ptr,
const TensorSpec& tensor_spec,
IDevice* device,
const std::function<void()>& on_creation_callback,
const std::function<void()>& on_destruction_callback,
const bool force_disable_borrow) {
const bool requires_padding = tensor_spec.logical_2d_shape() != tensor_spec.physical_shape();
const bool requires_tilization = tensor_spec.layout() != Layout::ROW_MAJOR;
const bool enable_borrow = !requires_padding and !requires_tilization and !force_disable_borrow;

TT_FATAL(
!tensor_spec.memory_config().is_sharded() or tensor_spec.memory_config().shard_spec.has_value(),
"Sharded tensors must have a shard spec when converting to tt tensors!");

// Use the template type to keep this function generic. TODO: find a better way, e.g. decltype or variants with an array or map?
auto data_ptr = reinterpret_cast<T*>(py_data_ptr);

std::size_t num_elements = tensor_spec.logical_shape().volume();

// Never enable borrowing for BFLOAT8_B and BFLOAT4_B, since they are TT-specific packed types
if (enable_borrow and
    !(tensor_spec.data_type() == DataType::BFLOAT8_B || tensor_spec.data_type() == DataType::BFLOAT4_B)) {
auto storage = BorrowedStorage(
borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback);
return Tensor(std::move(storage), tensor_spec);
} else {
auto logical_data = std::vector<T>(data_ptr, data_ptr + num_elements);

// Abstract away dtype handling by calling from_vector, which delegates to from_span and handles the bfloat formats
return Tensor::from_vector(
std::move(logical_data),
tensor_spec,
device == nullptr ? std::nullopt : std::optional<ttnn::AnyDevice>(device));
}
}

Tensor create_tt_tensor_from_py_data(
std::size_t py_data_ptr,
const TensorSpec& tensor_spec,
@@ -94,97 +133,43 @@
const std::function<void()>& on_destruction_callback) {
auto layout = tensor_spec.layout();

const bool requires_padding = tensor_spec.logical_2d_shape() != tensor_spec.physical_shape();
const bool requires_tilization = layout != Layout::ROW_MAJOR;
const bool enable_borrow = !requires_padding and !requires_tilization and !force_disable_borrow;

auto data_type = tensor_spec.data_type();
std::size_t num_elements = tensor_spec.logical_shape().volume();
switch (data_type) {
case DataType::UINT8: {
auto data_ptr = reinterpret_cast<uint8_t*>(py_data_ptr);
if (enable_borrow) {
auto storage = BorrowedStorage(
borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback);
return Tensor(std::move(storage), tensor_spec);
} else {
return create_owned_tensor(data_ptr, tensor_spec);
}
return create_typed_tt_tensor_from_py_data<uint8_t>(
py_data_ptr, tensor_spec, device, on_creation_callback, on_destruction_callback, force_disable_borrow);
}
case DataType::UINT16: {
auto data_ptr = reinterpret_cast<uint16_t*>(py_data_ptr);
if (enable_borrow) {
auto storage = BorrowedStorage(
borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback);
return Tensor(std::move(storage), tensor_spec);
} else {
return create_owned_tensor(data_ptr, tensor_spec);
}
return create_typed_tt_tensor_from_py_data<uint16_t>(
py_data_ptr, tensor_spec, device, on_creation_callback, on_destruction_callback, force_disable_borrow);
}
case DataType::INT32: {
auto data_ptr = reinterpret_cast<int32_t*>(py_data_ptr);
if (enable_borrow) {
auto storage = BorrowedStorage(
borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback);
return Tensor(std::move(storage), tensor_spec);
} else {
return create_owned_tensor(data_ptr, tensor_spec);
}
return create_typed_tt_tensor_from_py_data<int32_t>(
py_data_ptr, tensor_spec, device, on_creation_callback, on_destruction_callback, force_disable_borrow);
}
case DataType::UINT32: {
auto data_ptr = reinterpret_cast<uint32_t*>(py_data_ptr);
if (enable_borrow) {
auto storage = BorrowedStorage(
borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback);
return Tensor(std::move(storage), tensor_spec);
} else {
return create_owned_tensor(data_ptr, tensor_spec);
}
return create_typed_tt_tensor_from_py_data<uint32_t>(
py_data_ptr, tensor_spec, device, on_creation_callback, on_destruction_callback, force_disable_borrow);
}
case DataType::FLOAT32: {
auto data_ptr = reinterpret_cast<float*>(py_data_ptr);
if (enable_borrow) {
auto storage = BorrowedStorage(
borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback);
return Tensor(std::move(storage), tensor_spec);
} else {
return create_owned_tensor(data_ptr, tensor_spec);
}
return create_typed_tt_tensor_from_py_data<float>(
py_data_ptr, tensor_spec, device, on_creation_callback, on_destruction_callback, force_disable_borrow);
}
// TODO: This is not supported for numpy
case DataType::BFLOAT16: {
auto data_ptr = reinterpret_cast<::bfloat16*>(py_data_ptr);
if (enable_borrow) {
auto storage = BorrowedStorage(
borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback);
return Tensor(std::move(storage), tensor_spec);
} else {
return create_owned_tensor(data_ptr, tensor_spec);
}
return create_typed_tt_tensor_from_py_data<bfloat16>(
py_data_ptr, tensor_spec, device, on_creation_callback, on_destruction_callback, force_disable_borrow);
}
case DataType::BFLOAT8_B:
case DataType::BFLOAT4_B: {
auto data_ptr = reinterpret_cast<float*>(py_data_ptr);
auto float_tensor_spec = TensorSpec(
tensor_spec.logical_shape(),
TensorLayout(DataType::FLOAT32, tensor_spec.page_config(), tensor_spec.memory_config()));
auto float_tensor = create_owned_tensor(data_ptr, float_tensor_spec);

auto tile = tensor_spec.tensor_layout().get_page_config().get_tile();
auto output_float_data = owned_buffer::get_as<float>(float_tensor).get();
auto output_packed_data = data_type == DataType::BFLOAT8_B
? pack_fp32_vec_as_bfp8_tiles(
output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile)
: pack_fp32_vec_as_bfp4_tiles(
output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile);
auto output_buffer = owned_buffer::create<uint32_t>(std::move(output_packed_data));
return Tensor(std::move(OwnedStorage{std::move(output_buffer)}), tensor_spec);
}
default: {
TT_THROW("Unsupported DataType: {}", data_type);
break;
return create_typed_tt_tensor_from_py_data<bfloat16>(
py_data_ptr, tensor_spec, device, on_creation_callback, on_destruction_callback, force_disable_borrow);
}
}

// Default case removed from the switch; unhandled dtypes reach the TT_THROW below
TT_THROW("Unsupported DataType: {}", data_type);
}

Tensor convert_python_tensor_to_tt_tensor(
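
The net effect of the pytensor.cpp change: the per-dtype switch now only selects a template instantiation, and the borrow-or-copy decision lives in one place. A standalone sketch of that predicate, assuming tt-metal's tensor headers and using only accessors that appear in the diff (the free function itself is hypothetical, not part of the PR):

bool can_borrow_py_buffer(const TensorSpec& spec, bool force_disable_borrow) {
    // Borrowing the Python buffer is only safe when its bytes can be used verbatim:
    // no tile padding, no tilization, no TT-specific packed dtype, no explicit opt-out.
    const bool requires_padding = spec.logical_2d_shape() != spec.physical_shape();
    const bool requires_tilization = spec.layout() != Layout::ROW_MAJOR;
    const bool is_packed_dtype = spec.data_type() == DataType::BFLOAT8_B ||
                                 spec.data_type() == DataType::BFLOAT4_B;
    return !requires_padding && !requires_tilization && !is_packed_dtype && !force_disable_borrow;
}

Everything that fails this check goes through Tensor::from_vector, which owns a copy of the data and performs tilization and bfp8/bfp4 packing internally, replacing the hand-rolled pack_fp32_vec_as_bfp8_tiles/pack_fp32_vec_as_bfp4_tiles path deleted above.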
4 changes: 3 additions & 1 deletion ttnn/cpp/ttnn/tensor/tensor.cpp
@@ -40,7 +40,9 @@ Tensor create_owned_tensor_from_row_major_data(
spec.logical_shape(),
TensorLayout(spec.data_type(), PageConfig(Layout::ROW_MAJOR, spec.tile()), MemoryConfig{}));

Tensor output(OwnedStorage{owned_buffer::create(std::move(data))}, result_cpu_spec);
auto physical_data = tensor_impl::encode_tensor_data(std::move(data), result_cpu_spec);

Tensor output(OwnedStorage{owned_buffer::create(std::move(physical_data))}, result_cpu_spec);

if (spec.layout() == Layout::TILE) {
// TODO: whenever possible, perform tilization on device.
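
For intuition about what encode_tensor_data adds here: it maps logical row-major values onto the physical buffer the spec demands, inserting whatever padding or reordering the layout requires (for an unpadded row-major spec it can be a plain copy), so the owned storage size always matches the spec. A hedged sketch mirroring the new lines, reusing the data/spec/result_cpu_spec names from the diff rather than a standalone program:

std::vector<float> data(spec.logical_shape().volume(), 1.0f);
// encode_tensor_data converts logical values into the physical layout described
// by result_cpu_spec; constructing the Tensor from unencoded data (the deleted
// line above) breaks whenever logical and physical sizes differ.
auto physical = tensor_impl::encode_tensor_data(std::move(data), result_cpu_spec);
Tensor host_tensor(OwnedStorage{owned_buffer::create(std::move(physical))}, result_cpu_spec);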