Skip to content

Commit

Permalink
basic working jax-cuda implementation for forwards.
Browse files Browse the repository at this point in the history
  • Loading branch information
nickjbrowning committed Jan 15, 2024
1 parent 343c333 commit cfa089c
Show file tree
Hide file tree
Showing 14 changed files with 389 additions and 139 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def run(self):
CUDA_HOME = os.environ.get("CUDA_HOME")
if CUDA_HOME is not None:
cmake_options.append(f"-DCUDA_TOOLKIT_ROOT_DIR={CUDA_HOME}")
cmake_options.append("-DSPHERICART_ENABLE_CUDA=ON")

if sys.platform.startswith("darwin"):
cmake_options.append("-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=11.0")
Expand Down
48 changes: 41 additions & 7 deletions sphericart-jax/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ endif()

project(sphericart_jax CXX)

# Probe for a CUDA compiler without failing the configure step when none is
# installed: check_language() sets CMAKE_CUDA_COMPILER only if one was found.
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
# A compiler exists, so the CUDA language can be enabled safely.
enable_language(CUDA)
# FORCE overwrites any value already stored in the cache for this variable.
# NOTE(review): presumably this is the FindCUDA switch selecting the shared
# CUDA runtime -- confirm it still has an effect with enable_language(CUDA).
set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE BOOL "" FORCE)
else()
# Not an error: configuration continues without the CUDA language enabled.
message(STATUS "Could not find a CUDA compiler")
endif()


# Set a default build type if none was specified
if (${CMAKE_CURRENT_SOURCE_DIR} STREQUAL ${CMAKE_SOURCE_DIR})
if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "")
Expand All @@ -23,19 +33,14 @@ endif()

find_package(pybind11 CONFIG REQUIRED)



set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE)
add_subdirectory(sphericart EXCLUDE_FROM_ALL)



# CPU op library
pybind11_add_module(sphericart_jax ${CMAKE_CURRENT_LIST_DIR}/src/jax.cpp)
set(CPU_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/jax.cpp)
pybind11_add_module(sphericart_jax ${CPU_SOURCES})
install(TARGETS sphericart_jax DESTINATION sphericart_jax)



target_link_libraries(sphericart_jax PUBLIC sphericart)
target_compile_features(sphericart_jax PUBLIC cxx_std_17)

Expand All @@ -48,3 +53,32 @@ install(TARGETS sphericart_jax
LIBRARY DESTINATION "lib"
)

# Optional CUDA (GPU) op library.
#
# The sphericart_jax_gpu module is built only when both conditions hold:
#   * check_language() found a CUDA compiler (CMAKE_CUDA_COMPILER is set), and
#   * the caller requested it with -DSPHERICART_ENABLE_CUDA=ON.
# Otherwise this whole section is skipped and only the CPU module is built.
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER AND SPHERICART_ENABLE_CUDA)
    enable_language(CUDA)

    set(CUDA_SOURCES
        ${CMAKE_CURRENT_LIST_DIR}/src/jax_cuda.cu
        ${CMAKE_CURRENT_LIST_DIR}/src/jax_cuda.cpp
    )
    pybind11_add_module(sphericart_jax_gpu ${CUDA_SOURCES})

    # "native" compiles for the GPU(s) present on the build machine.
    set_target_properties(sphericart_jax_gpu PROPERTIES CUDA_ARCHITECTURES native)

    target_link_libraries(sphericart_jax_gpu PUBLIC sphericart)
    target_compile_features(sphericart_jax_gpu PUBLIC cxx_std_17)

    target_include_directories(sphericart_jax_gpu PUBLIC
        $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
        $<INSTALL_INTERFACE:include>
    )

    # The CUDA toolkit headers are attached to this target only (instead of a
    # directory-wide include_directories()), so they do not leak into the
    # CPU-only targets defined in this directory. They are needed because the
    # plain C++ source jax_cuda.cpp is part of this target as well.
    target_include_directories(sphericart_jax_gpu PRIVATE
        ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
    )

    install(TARGETS sphericart_jax_gpu DESTINATION sphericart_jax_gpu)
    install(TARGETS sphericart_jax_gpu
        LIBRARY DESTINATION "lib"
    )
endif()



43 changes: 43 additions & 0 deletions sphericart-jax/include/sphericart/jax_cuda.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@

#ifndef SPHERICART_JAX_CUDA_HPP
#define SPHERICART_JAX_CUDA_HPP

// std::size_t and std::int64_t are used below; include their headers directly
// rather than relying on the CUDA headers to pull them in transitively.
#include <cstddef>
#include <cstdint>

#include <cuda_runtime.h>
#include <cuda_runtime_api.h>

/// Parameters describing one spherical-harmonics evaluation.
///
/// NOTE(review): presumably this is the structure that travels to the CUDA
/// entry points below as the `opaque` byte buffer -- confirm against the
/// implementation in jax_cuda.cu / jax_cuda.cpp.
struct SphDescriptor {
    std::int64_t n_samples; ///< number of samples to evaluate
    std::int64_t lmax;      ///< maximum degree of the spherical harmonics
    bool normalize;         ///< normalization flag
};

namespace sphericart_jax {

namespace cuda {

// All entry points share one signature:
//   stream      CUDA stream handed to the function
//   in          array of buffer pointers (layout defined by the
//               implementation; not visible from this header)
//   opaque      pointer to an opaque byte buffer
//   opaque_len  length of `opaque` in bytes
// The `_f32` / `_f64` suffix distinguishes the two variants of each operation.

/// Spherical harmonics.
void apply_cuda_sph_f32(cudaStream_t stream, void **in, const char *opaque,
                        std::size_t opaque_len);

void apply_cuda_sph_f64(cudaStream_t stream, void **in, const char *opaque,
                        std::size_t opaque_len);

/// Spherical harmonics with gradients.
void apply_cuda_sph_with_gradients_f32(cudaStream_t stream, void **in,
                                       const char *opaque,
                                       std::size_t opaque_len);

void apply_cuda_sph_with_gradients_f64(cudaStream_t stream, void **in,
                                       const char *opaque,
                                       std::size_t opaque_len);

/// Spherical harmonics with hessians.
void apply_cuda_sph_with_hessians_f32(cudaStream_t stream, void **in,
                                      const char *opaque,
                                      std::size_t opaque_len);

void apply_cuda_sph_with_hessians_f64(cudaStream_t stream, void **in,
                                      const char *opaque,
                                      std::size_t opaque_len);

} // namespace cuda
} // namespace sphericart_jax

#endif // SPHERICART_JAX_CUDA_HPP
16 changes: 16 additions & 0 deletions sphericart-jax/include/sphericart/pybind11_kernel_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ template <typename T> pybind11::capsule EncapsulateFunction(T *fn) {
return pybind11::capsule(bit_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
}

/// Copy the raw bytes of `descriptor` into a std::string of length sizeof(T).
///
/// The object's address is reinterpreted as a byte pointer through the
/// file's `bit_cast` helper; exactly sizeof(T) bytes are copied.
template <typename T>
std::string PackDescriptorAsString(const T &descriptor) {
    const char *raw_bytes = bit_cast<const char *>(&descriptor);
    return std::string(raw_bytes, sizeof(T));
}

/// Pack `descriptor` into a Python `bytes` object.
///
/// The byte string is produced by PackDescriptorAsString() and then wrapped
/// in pybind11::bytes.
template <typename T>
pybind11::bytes PackDescriptor(const T &descriptor) {
    const std::string packed = PackDescriptorAsString(descriptor);
    return pybind11::bytes(packed);
}

/// Reinterpret an opaque byte buffer as a pointer to a `T`.
///
/// @param opaque      start of the byte buffer
/// @param opaque_len  buffer length in bytes; must equal sizeof(T)
/// @return `opaque` converted to `const T *` through the file's `bit_cast`
///         helper
/// @throws std::runtime_error if `opaque_len` differs from sizeof(T)
template <typename T>
const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
    const bool size_matches = (opaque_len == sizeof(T));
    if (!size_matches) {
        throw std::runtime_error("Invalid opaque object size");
    }

    const T *typed = bit_cast<const T *>(opaque);
    return typed;
}

} // namespace sphericart_jax

#endif
22 changes: 17 additions & 5 deletions sphericart-jax/python/sphericart/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
import jax
from .lib import sphericart_jax
from .spherical_harmonics import spherical_harmonics


# register the operations to xla
for _name, _value in sphericart_jax.registrations().items():
if _name.startswith("cpu_"):
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="cpu")
elif _name.startswith("cuda_"):
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="gpu")
else:
raise NotImplementedError(f"Unsupported target {_name}")
raise NotImplementedError(f"Unsupported target in sphericart_jax_cpu {_name}")

# The GPU extension is optional: it only exists when the package was built
# with CUDA support. Only a failed import is tolerated here; any problem that
# occurs while registering the targets is a real error and must surface
# instead of being silently swallowed.
try:
    from .lib import sphericart_jax_gpu
except ImportError:
    # No GPU extension available; only the targets registered above are usable.
    pass
else:
    # register the operations to xla
    for _name, _value in sphericart_jax_gpu.registrations().items():
        if _name.startswith("cuda_"):
            jax.lib.xla_client.register_custom_call_target(
                _name, _value, platform="gpu"
            )
        else:
            raise NotImplementedError(
                f"Unsupported target in sphericart_jax_gpu {_name}"
            )



15 changes: 10 additions & 5 deletions sphericart-jax/python/sphericart/jax/sph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .dsph import dsph
from .utils import default_layouts

from .lib.sphericart_jax_gpu import build_sph_descriptor

# register the sph primitive
_sph_p = core.Primitive("sph_fwd")
Expand Down Expand Up @@ -91,6 +92,7 @@ def sph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c):
out_shape = xyz_shape[:-1] + [
sph_size,
]
print (out_shape)
n_samples = math.prod(xyz_shape[:-1])

# make sure we dispatch to the correct implementation
Expand All @@ -101,6 +103,8 @@ def sph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c):
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

descriptor = build_sph_descriptor(n_samples, l_max_c, bool(normalized))

return custom_call(
op_name,
# Output types
Expand All @@ -109,14 +113,15 @@ def sph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c):
],
# inputs to the binded functions
operands=[
xyz,
mlir.ir_constant(l_max_c),
normalized,
mlir.ir_constant(n_samples),
xyz
# mlir.ir_constant(l_max_c),
# normalized,
# mlir.ir_constant(n_samples),
],
# Layout specification:
operand_layouts=default_layouts(xyz_shape, (), (), ()),
operand_layouts=default_layouts(xyz_shape),
result_layouts=default_layouts(out_shape),
backend_config=descriptor
).results


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@
# jax.config.update("jax_platform_name", "cpu")
import jax.numpy as jnp
import jax._src.test_util as jtu
import sphericart.jax
from sphericart.jax.spherical_harmonics import spherical_harmonics

def xyz():
key = jax.random.PRNGKey(0)
return 6 * jax.random.normal(key, (100, 3))


def compute(xyz):
sph = sphericart.jax.spherical_harmonics(l_max=4, normalized=False, xyz=xyz)
sph = spherical_harmonics(l_max=4, normalized=False, xyz=xyz)
assert jnp.linalg.norm(sph) != 0.0
return sph.sum()


xyzs = jax.device_put(xyz(), device=jax.devices('gpu')[0])

sph = compute(xyz())

print (sph)
print ("sum sph:", sph)
print ("cuda jax succesful")
1 change: 1 addition & 0 deletions sphericart-jax/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def run(self):
CUDA_HOME = os.environ.get("CUDA_HOME")
if CUDA_HOME is not None:
cmake_options.append(f"-DCUDA_TOOLKIT_ROOT_DIR={CUDA_HOME}")
cmake_options.append("-DSPHERICART_ENABLE_CUDA=ON")

if sys.platform.startswith("darwin"):
cmake_options.append("-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=11.0")
Expand Down
Loading

0 comments on commit cfa089c

Please sign in to comment.