Skip to content

Commit

Permalink
#17477: Move SmallVector and ShapeBase to re-use within Metal (#17669)
Browse files Browse the repository at this point in the history
### Ticket
#17477

### Problem description
TT-distributed requires ND shapes in Metal. Instead of having our own,
moving `ShapeBase` and `SmallVector` into Metal.

### What's changed
* Moved `ShapeBase` and `SmallVector` into Metal. `SmallVector` really
should be part of our "stl" library, so I put it there for now. I used
`ttsl` namespace - my goal is to replace `tt::stl` as it is shorter and
easier to type, also avoids confusion if anyone attempts to use `stl::`.

### Checklist
- [x] [All post
commit](https://github.com/tenstorrent/tt-metal/actions/runs/13190528665)
- [X] New/Existing tests provide coverage for changes
  • Loading branch information
omilyutin-tt authored Feb 7, 2025
1 parent 8e4efa2 commit b552fb8
Show file tree
Hide file tree
Showing 23 changed files with 57 additions and 49 deletions.
3 changes: 3 additions & 0 deletions dependencies/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ CPMAddPackage(
add_library(span INTERFACE)
target_link_libraries(span INTERFACE Boost::core)

add_library(small_vector INTERFACE)
target_link_libraries(small_vector INTERFACE Boost::container)

############################################################################################################################
# yaml-cpp
############################################################################################################################
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ target_link_libraries(
magic_enum
fmt::fmt-header-only
span
small_vector
)

if(TT_METAL_BUILD_TESTS)
Expand Down
1 change: 1 addition & 0 deletions tests/tt_metal/tt_metal/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ set(UNIT_TESTS_API_SRC
${CMAKE_CURRENT_SOURCE_DIR}/test_noc.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_runtime_args.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_semaphores.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_shape_base.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_sharded_l1_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_simple_dram_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_simple_l1_buffer.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <exception>
#include "gtest/gtest.h"

#include "ttnn/tensor/shape/shape_base.hpp"
#include <tt-metalium/shape_base.hpp>

TEST(TensorShapeBaseTests, General4D) {
tt::tt_metal::ShapeBase vec({20, 30, 40, 50});
Expand Down
1 change: 0 additions & 1 deletion tests/ttnn/unit_tests/gtests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ set(TTNN_TENSOR_UNIT_TESTS_SRC
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_distributed_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_mesh_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_partition.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_shape_base.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_tensor_sharding.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_vector_conversion.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_xtensor_conversion.cpp
Expand Down
1 change: 1 addition & 0 deletions tt_metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ target_link_libraries(
magic_enum
fmt::fmt-header-only
span
small_vector
TracyClient
nlohmann_json::nlohmann_json
TT::Metalium::HostDevCommon
Expand Down
1 change: 1 addition & 0 deletions tt_metal/api/tt-metalium/shape2d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

namespace tt::tt_metal {

// Simplified 2D shape for cases that fundamentally require 2 dimensions (e.g. core grid).
class Shape2D final {
public:
Shape2D(std::size_t height, std::size_t width);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace tt::tt_metal {
// Container wrapper that allows negative indexing
class ShapeBase {
public:
using Container = SmallVector<uint32_t>;
using Container = tt::stl::SmallVector<uint32_t>;

ShapeBase() { init(); };
explicit ShapeBase(const Container& shape) : value_(shape) { init(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,9 @@

#include <boost/container/small_vector.hpp>

#include <tt-metalium/reflection.hpp>
#include "reflection.hpp"

#if TTNN_WITH_PYTHON_BINDINGS
#include <pybind11/stl.h>
#endif

namespace tt::tt_metal {
namespace tt::stl {

static constexpr size_t SMALL_VECTOR_SIZE = 8;

Expand All @@ -35,15 +31,16 @@ std::ostream& operator<<(std::ostream& os, const SmallVector<T, PREALLOCATED_SIZ
return os;
}

} // namespace tt::tt_metal
} // namespace tt::stl

// TODO: remove this.
namespace ttnn {
using tt::tt_metal::SmallVector;
using tt::stl::SmallVector;
}

template <typename T, size_t PREALLOCATED_SIZE>
struct std::hash<tt::tt_metal::SmallVector<T, PREALLOCATED_SIZE>> {
size_t operator()(const ttnn::SmallVector<T, PREALLOCATED_SIZE>& vec) const noexcept {
struct std::hash<tt::stl::SmallVector<T, PREALLOCATED_SIZE>> {
size_t operator()(const tt::stl::SmallVector<T, PREALLOCATED_SIZE>& vec) const noexcept {
size_t hash = 0;
for (const auto& element : vec) {
hash = tt::stl::hash::detail::hash_objects(hash, element);
Expand All @@ -53,23 +50,13 @@ struct std::hash<tt::tt_metal::SmallVector<T, PREALLOCATED_SIZE>> {
};

template <typename T, size_t PREALLOCATED_SIZE>
struct fmt::formatter<tt::tt_metal::SmallVector<T, PREALLOCATED_SIZE>> {
struct fmt::formatter<tt::stl::SmallVector<T, PREALLOCATED_SIZE>> {
constexpr auto parse(format_parse_context& ctx) -> format_parse_context::iterator { return ctx.end(); }

auto format(const tt::tt_metal::SmallVector<T, PREALLOCATED_SIZE>& vector, format_context& ctx) const
auto format(const tt::stl::SmallVector<T, PREALLOCATED_SIZE>& vector, format_context& ctx) const
-> format_context::iterator {
std::stringstream ss;
ss << vector;
return fmt::format_to(ctx.out(), "{}", ss.str());
}
};

#if TTNN_WITH_PYTHON_BINDINGS
namespace PYBIND11_NAMESPACE {
namespace detail {
template <typename T, size_t PREALLOCATED_SIZE>
struct type_caster<tt::tt_metal::SmallVector<T, PREALLOCATED_SIZE>>
: list_caster<tt::tt_metal::SmallVector<T, PREALLOCATED_SIZE>, T> {};
} // namespace detail
} // namespace PYBIND11_NAMESPACE
#endif
2 changes: 2 additions & 0 deletions tt_metal/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ set(COMMON_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/core_descriptor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal_soc_descriptor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/shape2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/shape_base.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tt_backend_api_types.cpp
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/work_split.cpp
Expand All @@ -20,6 +21,7 @@ target_link_libraries(
magic_enum
fmt::fmt-header-only
span
small_vector
Metalium::Metal::STL
umd::Firmware
umd::device
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "assert.hpp"
#include "shape_base.hpp"
#include <stdexcept>
#include "fmt/color.h"
#include <tt-metalium/assert.hpp>

namespace tt::tt_metal {

Expand Down
12 changes: 0 additions & 12 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,6 @@ set(TTNN_PUBLIC_INCLUDE_DIRS
set(TTNN_PUBLIC_LINK_LIBRARIES
metal_common_libs
Metalium::Metal
Boost::container
xtensor
xtensor-blas
xtl
Expand Down Expand Up @@ -758,11 +757,6 @@ function(add_ttnn_sublibrary SUBLIBRARY_NAME)
add_library(${SUBLIBRARY_NAME} OBJECT ${ARGN})
endif()
TT_ENABLE_UNITY_BUILD(${SUBLIBRARY_NAME})
if(WITH_PYTHON_BINDINGS)
target_compile_definitions(${SUBLIBRARY_NAME} PUBLIC TTNN_WITH_PYTHON_BINDINGS=1)
else()
target_compile_definitions(${SUBLIBRARY_NAME} PUBLIC TTNN_WITH_PYTHON_BINDINGS=0)
endif()
target_include_directories(${SUBLIBRARY_NAME} PUBLIC ${TTNN_PUBLIC_INCLUDE_DIRS})
target_link_libraries(${SUBLIBRARY_NAME} PUBLIC ${TTNN_PUBLIC_LINK_LIBRARIES})
target_link_directories(${SUBLIBRARY_NAME} PUBLIC ${TTNN_PUBLIC_LINK_DIRS})
Expand Down Expand Up @@ -826,12 +820,6 @@ target_compile_options(
-fno-var-tracking
)

if(WITH_PYTHON_BINDINGS)
target_compile_definitions(ttnn PUBLIC TTNN_WITH_PYTHON_BINDINGS=1)
else()
target_compile_definitions(ttnn PUBLIC TTNN_WITH_PYTHON_BINDINGS=0)
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
target_compile_definitions(ttnn PUBLIC DISABLE_NAMESPACE_STATIC_ASSERT)
endif()
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/decorators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <type_traits>

#include "ttnn/decorators.hpp"
#include "small_vector_caster.hpp" // NOLINT - for pybind11 SmallVector binding support.
#include "ttnn/types.hpp"

namespace py = pybind11;
Expand Down
1 change: 1 addition & 0 deletions ttnn/cpp/pybind11/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "small_vector_caster.hpp" // NOLINT - for pybind11 SmallVector binding support.
#include <tt-metalium/persistent_kernel_cache.hpp>
#include <tt-metalium/compilation_reporter.hpp>
#include <tt-metalium/memory_reporter.hpp>
Expand Down
1 change: 1 addition & 0 deletions ttnn/cpp/pybind11/pytensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <chrono>
#include <memory>

#include "small_vector_caster.hpp" // NOLINT - for pybind11 SmallVector binding support.
#include "ttnn/tensor/tensor.hpp"
#include <tt-metalium/graph_tracking.hpp>
#include <tt-metalium/overloaded.hpp>
Expand Down
19 changes: 19 additions & 0 deletions ttnn/cpp/pybind11/small_vector_caster.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <tt-metalium/small_vector.hpp>

namespace PYBIND11_NAMESPACE {
namespace detail {
template <typename T, size_t PREALLOCATED_SIZE>
struct type_caster<ttnn::SmallVector<T, PREALLOCATED_SIZE>> : list_caster<ttnn::SmallVector<T, PREALLOCATED_SIZE>, T> {
};
} // namespace detail
} // namespace PYBIND11_NAMESPACE
3 changes: 3 additions & 0 deletions ttnn/cpp/pybind11/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <tt-metalium/small_vector.hpp>

#include "export_enum.hpp"
#include "small_vector_caster.hpp" // NOLINT - for pybind11 SmallVector binding support.
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/types.hpp"
#include "ttnn/operations/data_movement/bcast/bcast_types.hpp"
Expand Down
1 change: 0 additions & 1 deletion ttnn/cpp/ttnn/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ set(TENSOR_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/tensor_spec.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/serialization.cpp
${CMAKE_CURRENT_SOURCE_DIR}/shape/shape_base.cpp
${CMAKE_CURRENT_SOURCE_DIR}/shape/shape.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/alignment.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/page_config.cpp
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/tensor/layout/alignment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace tt::tt_metal {

bool Alignment::operator==(const Alignment& other) const = default;

bool Alignment::operator==(const SmallVector<uint32_t>& other) const { return this->value_ == other; }
bool Alignment::operator==(const tt::stl::SmallVector<uint32_t>& other) const { return this->value_ == other; }

std::ostream& operator<<(std::ostream& os, const tt::tt_metal::Alignment& alignment) {
os << "Alignment([";
Expand Down
6 changes: 3 additions & 3 deletions ttnn/cpp/ttnn/tensor/layout/alignment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

#pragma once

#include "ttnn/tensor/shape/shape_base.hpp"
#include "ttnn/tensor/shape/small_vector.hpp"
#include <tt-metalium/small_vector.hpp>
#include <tt-metalium/shape_base.hpp>

namespace tt::tt_metal {

Expand All @@ -26,7 +26,7 @@ class Alignment final : protected ShapeBase {
}

bool operator==(const Alignment& other) const;
bool operator==(const SmallVector<uint32_t>& other) const;
bool operator==(const tt::stl::SmallVector<uint32_t>& other) const;

// Needed for reflect / fmt
static constexpr auto attribute_names = std::forward_as_tuple("value");
Expand Down
7 changes: 4 additions & 3 deletions ttnn/cpp/ttnn/tensor/shape/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

#include <numeric>
#include <ostream>
#include "ttnn/tensor/shape/small_vector.hpp"

#include <tt-metalium/assert.hpp>
#include <tt-metalium/small_vector.hpp>

namespace tt::tt_metal {

bool Shape::operator==(const Shape& other) const = default;

bool Shape::operator==(const SmallVector<uint32_t>& other) const { return this->value_ == other; }
bool Shape::operator==(const tt::stl::SmallVector<uint32_t>& other) const { return this->value_ == other; }

size_t Shape::rank() const { return this->size(); }

Expand All @@ -29,7 +30,7 @@ std::array<uint32_t, 4> Shape::to_array_4D() const {
}

Shape Shape::to_rank(size_t new_rank) const {
SmallVector<uint32_t> new_shape(new_rank, 1);
tt::stl::SmallVector<uint32_t> new_shape(new_rank, 1);

int cur_idx = static_cast<int>(rank()) - 1;
int new_idx = static_cast<int>(new_rank) - 1;
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/tensor/shape/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#pragma once

#include "shape_base.hpp"
#include <tt-metalium/shape_base.hpp>

namespace tt::tt_metal {

Expand Down
3 changes: 2 additions & 1 deletion ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

#pragma once

#include "ttnn/tensor/shape/small_vector.hpp"
#include <tt-metalium/small_vector.hpp>

#include "ttnn/tensor/tensor.hpp"
#include <ttnn/tensor/xtensor/xtensor_all_includes.hpp>

Expand Down

0 comments on commit b552fb8

Please sign in to comment.