Make oneMKL optional for the torch_xpu_ops target (USE_ONEMKL).

cmake/ONEMKL.cmake
  Adds option USE_ONEMKL (default ON). If the USE_ONEMKL environment
  variable is defined, its value replaces the option. When the option is
  off, the module returns before find_package(ONEMKL).
src/CMakeLists.txt
  Adds the USE_ONEMKL define and the oneMKL include dirs / libraries to
  torch_xpu_ops only when USE_ONEMKL is on.
src/ATen/native/xpu/SpectralOps.cpp
  With USE_ONEMKL: calls native::xpu::_fft_c2c_mkl{,_out} as before.
  Without it: calls native::_fft_c2c_mkl{,_out} on CPU copies of the
  tensors and moves the result back.
src/ATen/native/xpu/mkl/SpectralOps.cpp
  Wraps the file, from line 1 to after the closing namespace, in
  #if defined(USE_ONEMKL).
src/ATen/native/xpu/mkl/SpectralOps.h
  Adds one #include.

Review changes folded into this patch:
  * _fft_c2c_xpu fallback returns the result on self.device() instead of
    Device(at::kXPU), which carries no device index.
  * _fft_c2c_xpu_out fallback resizes `out` to the staging tensor's shape
    before copy_, so an `out` of a different shape is handled.
  * src/CMakeLists.txt passes the macro through target_compile_definitions
    instead of a raw "-D" compile option.

NOTE(review): the header names on the #include lines are missing from this
copy of the patch (apparently stripped together with the angle brackets);
restore them before applying.
NOTE(review): the index lines are unchanged from the original patch, so the
post-image ids of the two files touched by the review changes no longer match.

diff --git a/cmake/ONEMKL.cmake b/cmake/ONEMKL.cmake
index 2d40ebccf..73eb95177 100644
--- a/cmake/ONEMKL.cmake
+++ b/cmake/ONEMKL.cmake
@@ -1,3 +1,15 @@
+option(USE_ONEMKL "Build with ONEMKL XPU support" ON)
+
+if(DEFINED ENV{USE_ONEMKL})
+  set(USE_ONEMKL $ENV{USE_ONEMKL})
+endif()
+
+message(STATUS "USE_ONEMKL is set to ${USE_ONEMKL}")
+
+if(NOT USE_ONEMKL)
+  return()
+endif()
+
 find_package(ONEMKL)
 if(NOT ONEMKL_FOUND)
   message(FATAL_ERROR "Can NOT find ONEMKL cmake helpers module!")
diff --git a/src/ATen/native/xpu/SpectralOps.cpp b/src/ATen/native/xpu/SpectralOps.cpp
index af82394f1..311ecf140 100644
--- a/src/ATen/native/xpu/SpectralOps.cpp
+++ b/src/ATen/native/xpu/SpectralOps.cpp
@@ -1,6 +1,8 @@
-#include
+#if defined(USE_ONEMKL)
 #include
-#include
+#else
+#include
+#endif // USE_ONEMKL
 
 namespace at::native {
 
@@ -11,7 +13,14 @@ Tensor _fft_c2c_xpu(
     bool forward) {
   TORCH_CHECK(self.is_complex());
 
+#if defined(USE_ONEMKL)
   return native::xpu::_fft_c2c_mkl(self, dim, normalization, forward);
+#else
+  // No oneMKL: run native::_fft_c2c_mkl on a CPU copy, then move the result to `self`'s device.
+  Tensor out_cpu = native::_fft_c2c_mkl(
+      self.to(Device(at::kCPU)), dim, normalization, forward);
+  return out_cpu.to(self.device());
+#endif // USE_ONEMKL
 }
 
 Tensor& _fft_c2c_xpu_out(
@@ -22,7 +31,18 @@ Tensor& _fft_c2c_xpu_out(
     Tensor& out) {
   TORCH_CHECK(self.is_complex());
 
+#if defined(USE_ONEMKL)
   return native::xpu::_fft_c2c_mkl_out(self, dim, normalization, forward, out);
+#else
+  // No oneMKL: compute into a CPU staging tensor, then copy into `out`.
+  Tensor out_cpu = out.to(Device(at::kCPU));
+  native::_fft_c2c_mkl_out(
+      self.to(Device(at::kCPU)), dim, normalization, forward, out_cpu);
+  // `out_cpu` may have been resized by the call above; give `out` the same shape first.
+  out.resize_(out_cpu.sizes());
+  out.copy_(out_cpu);
+  return out;
+#endif // USE_ONEMKL
 }
 
 } // namespace at::native
diff --git a/src/ATen/native/xpu/mkl/SpectralOps.cpp b/src/ATen/native/xpu/mkl/SpectralOps.cpp
index bd3ad8b25..0dac86b51 100644
--- a/src/ATen/native/xpu/mkl/SpectralOps.cpp
+++ b/src/ATen/native/xpu/mkl/SpectralOps.cpp
@@ -1,3 +1,4 @@
+#if defined(USE_ONEMKL)
 #include
 #include
 #include
@@ -398,3 +399,4 @@ Tensor& _fft_c2c_mkl_out(
 }
 
 } // namespace at::native::xpu
+#endif // USE_ONEMKL
diff --git a/src/ATen/native/xpu/mkl/SpectralOps.h b/src/ATen/native/xpu/mkl/SpectralOps.h
index 0d66a6dae..504187397 100644
--- a/src/ATen/native/xpu/mkl/SpectralOps.h
+++ b/src/ATen/native/xpu/mkl/SpectralOps.h
@@ -1,5 +1,7 @@
 #pragma once
 
+#include
+
 namespace at::native::xpu {
 
 TORCH_XPU_API Tensor _fft_c2c_mkl(
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 661e48d88..cdecb270a 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -32,5 +32,8 @@ if(CLANG_FORMAT)
   add_dependencies(torch_xpu_ops CL_FORMAT_CSRCS)
 endif()
 
-target_include_directories(torch_xpu_ops PUBLIC ${TORCH_XPU_OPS_ONEMKL_INCLUDE_DIR})
-target_link_libraries(torch_xpu_ops PUBLIC ${TORCH_XPU_OPS_ONEMKL_LIBRARIES})
+if(USE_ONEMKL)
+  target_compile_definitions(torch_xpu_ops PRIVATE USE_ONEMKL)
+  target_include_directories(torch_xpu_ops PUBLIC ${TORCH_XPU_OPS_ONEMKL_INCLUDE_DIR})
+  target_link_libraries(torch_xpu_ops PUBLIC ${TORCH_XPU_OPS_ONEMKL_LIBRARIES})
+endif()