Skip to content

Commit

Permalink
Add ENV USE_ONEMKL to control whether to build with ONEMKL or not
Browse files Browse the repository at this point in the history
  • Loading branch information
CuiYifeng committed Jan 14, 2025
1 parent 86388a1 commit 65a0071
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 4 deletions.
12 changes: 12 additions & 0 deletions cmake/ONEMKL.cmake
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
option(USE_ONEMKL "Build with ONEMKL XPU support" ON)

if(DEFINED ENV{USE_ONEMKL})
set(USE_ONEMKL $ENV{USE_ONEMKL})
endif()

message(STATUS "USE_ONEMKL is set to ${USE_ONEMKL}")

if(NOT USE_ONEMKL)
return()
endif()

find_package(ONEMKL)
if(NOT ONEMKL_FOUND)
message(FATAL_ERROR "Can NOT find ONEMKL cmake helpers module!")
Expand Down
20 changes: 18 additions & 2 deletions src/ATen/native/xpu/SpectralOps.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include <ATen/native/Resize.h>
#if defined(USE_ONEMKL)
#include <ATen/native/xpu/mkl/SpectralOps.h>
#include <comm/xpu_aten.h>
#else
#include <ATen/ops/_fft_c2c_native.h>
#endif // USE_ONEMKL

namespace at::native {

Expand All @@ -11,7 +13,13 @@ Tensor _fft_c2c_xpu(
bool forward) {
TORCH_CHECK(self.is_complex());

#if defined(USE_ONEMKL)
return native::xpu::_fft_c2c_mkl(self, dim, normalization, forward);
#else
Tensor out_cpu = native::_fft_c2c_mkl(
self.to(Device(at::kCPU)), dim, normalization, forward);
return out_cpu.to(Device(at::kXPU));
#endif // USE_ONEMKL
}

Tensor& _fft_c2c_xpu_out(
Expand All @@ -22,7 +30,15 @@ Tensor& _fft_c2c_xpu_out(
Tensor& out) {
TORCH_CHECK(self.is_complex());

#if defined(USE_ONEMKL)
return native::xpu::_fft_c2c_mkl_out(self, dim, normalization, forward, out);
#else
Tensor out_cpu = out.to(Device(at::kCPU));
native::_fft_c2c_mkl_out(
self.to(Device(at::kCPU)), dim, normalization, forward, out_cpu);
out.copy_(out_cpu);
return out;
#endif // USE_ONEMKL
}

} // namespace at::native
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/mkl/SpectralOps.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#if defined(USE_ONEMKL)
#include <ATen/native/Resize.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <ATen/native/xpu/mkl/SpectralOps.h>
Expand Down Expand Up @@ -398,3 +399,4 @@ Tensor& _fft_c2c_mkl_out(
}

} // namespace at::native::xpu
#endif // USE_ONEMKL
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/mkl/SpectralOps.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <ATen/core/Tensor.h>

namespace at::native::xpu {

TORCH_XPU_API Tensor _fft_c2c_mkl(
Expand Down
7 changes: 5 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,8 @@ if(CLANG_FORMAT)
add_dependencies(torch_xpu_ops CL_FORMAT_CSRCS)
endif()

target_include_directories(torch_xpu_ops PUBLIC ${TORCH_XPU_OPS_ONEMKL_INCLUDE_DIR})
target_link_libraries(torch_xpu_ops PUBLIC ${TORCH_XPU_OPS_ONEMKL_LIBRARIES})
if(USE_ONEMKL)
target_compile_options(torch_xpu_ops PRIVATE "-DUSE_ONEMKL")
target_include_directories(torch_xpu_ops PUBLIC ${TORCH_XPU_OPS_ONEMKL_INCLUDE_DIR})
target_link_libraries(torch_xpu_ops PUBLIC ${TORCH_XPU_OPS_ONEMKL_LIBRARIES})
endif()

0 comments on commit 65a0071

Please sign in to comment.