Skip to content

Commit

Permalink
Move definitions to implementation for core.hpp (#17118)
Browse files Browse the repository at this point in the history
  • Loading branch information
blozano-tt authored Jan 27, 2025
1 parent a5796bc commit dc5def7
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 39 deletions.
40 changes: 40 additions & 0 deletions ttnn/cpp/ttnn/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,48 @@

#include "ttnn/core.hpp"

#include <magic_enum/magic_enum.hpp>

namespace ttnn::core {

bool has_storage_type_of(const ttnn::Tensor& tensor, const ttnn::StorageType& storage_type) {
return tensor.storage_type() == storage_type;
}

std::optional<ttnn::MemoryConfig> get_memory_config(const ttnn::Tensor& tensor) {
if (not tensor.is_allocated() or not is_tensor_on_device_or_multidevice(tensor)) {
return std::nullopt;
}
return tensor.memory_config();
}

void set_printoptions(const std::string& profile) {
tt::tt_metal::tensor_impl::TTNN_TENSOR_PRINT_PROFILE =
magic_enum::enum_cast<tt::tt_metal::tensor_impl::TensorPrintProfile>(profile, [](char lhs, char rhs) {
return std::tolower(lhs) == std::tolower(rhs);
}).value();
}

void segfault_handler(int sig) {
std::cerr << tt::assert::backtrace_to_string() << std::endl;
exit(EXIT_FAILURE);
}

void dump_stack_trace_on_segfault() {
if (std::signal(SIGSEGV, segfault_handler) == SIG_ERR) {
std::cerr << "Error: cannot handle SIGSEGV" << std::endl;
exit(EXIT_FAILURE);
}
}
} // namespace ttnn::core

namespace ttnn {

CoreIDs& CoreIDs::instance() {
static CoreIDs instance;
return instance;
}

std::int64_t CoreIDs::get_python_operation_id() { return python_operation_id.load(); }
void CoreIDs::set_python_operation_id(std::int64_t python_operation_id_) { python_operation_id = python_operation_id_; }
std::int64_t CoreIDs::fetch_and_increment_python_operation_id() { return python_operation_id.fetch_add(1); }
Expand Down
51 changes: 12 additions & 39 deletions ttnn/cpp/ttnn/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

#pragma once
#include <csignal>
#include <cstdint>
#include <optional>
#include <string>

#include <magic_enum/magic_enum.hpp>
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/tensor_impl.hpp" // TTNN_TENSOR_PRINT_PROFILE
#include "ttnn/tensor/types.hpp"
Expand All @@ -28,53 +29,25 @@ namespace ttnn {

namespace core {

inline std::uint32_t pad_to_multiple_of_tile_size(std::uint32_t value, std::uint32_t tile_size) {
return (value + (tile_size - 1)) / tile_size * tile_size;
}

inline bool has_storage_type_of(const ttnn::Tensor& tensor, const ttnn::StorageType& storage_type) {
return tensor.storage_type() == storage_type;
}

inline std::optional<ttnn::MemoryConfig> get_memory_config(const ttnn::Tensor& tensor) {
if (not tensor.is_allocated() or not is_tensor_on_device_or_multidevice(tensor)) {
return std::nullopt;
}
return tensor.memory_config();
}

inline void set_printoptions(const std::string& profile) {
tt::tt_metal::tensor_impl::TTNN_TENSOR_PRINT_PROFILE =
magic_enum::enum_cast<tt::tt_metal::tensor_impl::TensorPrintProfile>(profile, [](char lhs, char rhs) {
return std::tolower(lhs) == std::tolower(rhs);
}).value();
}

inline void segfault_handler(int sig) {
std::cerr << tt::assert::backtrace_to_string() << std::endl;
exit(EXIT_FAILURE);
}

inline void dump_stack_trace_on_segfault() {
if (std::signal(SIGSEGV, segfault_handler) == SIG_ERR) {
std::cerr << "Error: cannot handle SIGSEGV" << std::endl;
exit(EXIT_FAILURE);
}
}
bool has_storage_type_of(const ttnn::Tensor& tensor, const ttnn::StorageType& storage_type);

std::optional<ttnn::MemoryConfig> get_memory_config(const ttnn::Tensor& tensor);

void set_printoptions(const std::string& profile);

void segfault_handler(int sig);

void dump_stack_trace_on_segfault();

} // namespace core

using core::get_memory_config;
using core::has_storage_type_of;
using core::pad_to_multiple_of_tile_size;
using core::set_printoptions;

class CoreIDs {
public:
static CoreIDs& instance() {
static CoreIDs instance;
return instance;
}
static CoreIDs& instance();

std::int64_t get_python_operation_id();
void set_python_operation_id(std::int64_t operation_id);
Expand Down

0 comments on commit dc5def7

Please sign in to comment.