Skip to content

Commit

Permalink
refactor: split monolithic BUILD.local_cuda into template fragments
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Dec 14, 2024
1 parent 688c36e commit b7887dd
Show file tree
Hide file tree
Showing 32 changed files with 703 additions and 613 deletions.
59 changes: 21 additions & 38 deletions cuda/private/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
load("//cuda/private:template_helper.bzl", "template_helper")

def _to_forward_slash(s):
return s.replace("\\", "/")
Expand Down Expand Up @@ -103,42 +104,34 @@ def config_cuda_toolkit_and_nvcc(repository_ctx, cuda):
cuda: The struct returned from detect_cuda_toolkit
"""

# Generate @local_cuda//BUILD and @local_cuda//defs.bzl
defs_bzl_content = ""
defs_if_local_cuda = "def if_local_cuda(if_true, if_false = []):\n return %s\n"
# True: locally installed cuda toolkit
# False: hermatic cuda toolkit (components)
# None: cuda toolkit is not presented
is_local_cuda = None
if cuda.path != None:
# When using a special cuda toolkit path install, need to manually fix up the lib64 links
if cuda.path == "/usr/lib/nvidia-cuda-toolkit":
repository_ctx.symlink(cuda.path + "/bin", "cuda/bin")
repository_ctx.symlink("/usr/lib/x86_64-linux-gnu", "cuda/lib64")
else:
repository_ctx.symlink(cuda.path, "cuda")
repository_ctx.symlink(Label("//cuda:runtime/BUILD.local_cuda"), "BUILD")
defs_bzl_content += defs_if_local_cuda % "if_true"
is_local_cuda = True

# Generate @local_cuda//BUILD
if is_local_cuda == None:
repository_ctx.symlink(Label("//cuda/private:templates/BUILD.local_cuda_disabled"), "BUILD")
elif is_local_cuda:
libpath = "lib64" if _is_linux(repository_ctx) else "lib"
template_helper.generate_build(repository_ctx, libpath)
else:
repository_ctx.symlink(Label("//cuda:runtime/BUILD.local_cuda_disabled"), "BUILD")
defs_bzl_content += defs_if_local_cuda % "if_false"
repository_ctx.file("defs.bzl", defs_bzl_content)
# raise NotImplementedError("hermatic cuda toolchain is not implemented")
pass

# Generate @local_cuda//defs.bzl
template_helper.generate_defs_bzl(repository_ctx, is_local_cuda)

# Generate @local_cuda//toolchain/BUILD
tpl_label = Label(
"//cuda:templates/BUILD.local_toolchain_" +
("nvcc" if _is_linux(repository_ctx) else "nvcc_msvc"),
)
substitutions = {
"%{cuda_path}": _to_forward_slash(cuda.path) if cuda.path else "cuda-not-found",
"%{cuda_version}": "{}.{}".format(cuda.version_major, cuda.version_minor),
"%{nvcc_version_major}": str(cuda.nvcc_version_major),
"%{nvcc_version_minor}": str(cuda.nvcc_version_minor),
"%{nvlink_label}": cuda.nvlink_label,
"%{link_stub_label}": cuda.link_stub_label,
"%{bin2c_label}": cuda.bin2c_label,
"%{fatbinary_label}": cuda.fatbinary_label,
}
env_tmp = repository_ctx.os.environ.get("TMP", repository_ctx.os.environ.get("TEMP", None))
if env_tmp != None:
substitutions["%{env_tmp}"] = _to_forward_slash(env_tmp)
repository_ctx.template("toolchain/BUILD", tpl_label, substitutions = substitutions, executable = False)
template_helper.generate_toolchain_build(repository_ctx, cuda)

def detect_clang(repository_ctx):
"""Detect local clang installation.
Expand Down Expand Up @@ -178,20 +171,10 @@ def config_clang(repository_ctx, cuda, clang_path):
cuda: The struct returned from `detect_cuda_toolkit`
clang_path: Path to clang executable returned from `detect_clang`
"""
tpl_label = Label("//cuda:templates/BUILD.local_toolchain_clang")
substitutions = {
"%{clang_path}": _to_forward_slash(clang_path) if clang_path else "cuda-clang-not-found",
"%{cuda_path}": _to_forward_slash(cuda.path) if cuda.path else "cuda-not-found",
"%{cuda_version}": "{}.{}".format(cuda.version_major, cuda.version_minor),
"%{nvlink_label}": cuda.nvlink_label,
"%{link_stub_label}": cuda.link_stub_label,
"%{bin2c_label}": cuda.bin2c_label,
"%{fatbinary_label}": cuda.fatbinary_label,
}
repository_ctx.template("toolchain/clang/BUILD", tpl_label, substitutions = substitutions, executable = False)
template_helper.generate_toolchain_clang_build(repository_ctx, cuda, clang_path)

def config_disabled(repository_ctx):
repository_ctx.symlink(Label("//cuda:templates/BUILD.local_toolchain_disabled"), "toolchain/disabled/BUILD")
repository_ctx.symlink(Label("//cuda/private:templates/BUILD.local_toolchain_disabled"), "toolchain/disabled/BUILD")

def _local_cuda_impl(repository_ctx):
cuda = detect_cuda_toolkit(repository_ctx)
Expand Down
82 changes: 82 additions & 0 deletions cuda/private/template_helper.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
load("//cuda/private:templates/registry.bzl", "REGISTRY")

def _to_forward_slash(s):
return s.replace("\\", "/")

def _is_linux(ctx):
return ctx.os.name.startswith("linux")

def _is_windows(ctx):
return ctx.os.name.lower().startswith("windows")

def _generate_build(repository_ctx, libpath):
# stitch template fragment
fragments = [
Label("//cuda/private:templates/BUILD.local_cuda_shared"),
Label("//cuda/private:templates/BUILD.local_cuda_headers"),
Label("//cuda/private:templates/BUILD.local_cuda_build_setting"),
]
fragments.extend([Label("//cuda/private:templates/BUILD.{}".format(c)) for c in REGISTRY if len(REGISTRY[c]) > 0])

template_content = []
for frag in fragments:
template_content.append("# Generated from fragment " + str(frag))
template_content.append(repository_ctx.read(frag))

template_content = "\n".join(template_content)

template_path = repository_ctx.path("BUILD.tpl")
repository_ctx.file(template_path, content = template_content, executable = False)

substitutions = {
"%{component_name}": "cuda",
"%{libpath}": libpath,
}
repository_ctx.template("BUILD", template_path, substitutions = substitutions, executable = False)

def _generate_defs_bzl(repository_ctx, is_local_cuda):
tpl_label = Label("//cuda/private:templates/defs.bzl.tpl")
substitutions = {
"%{is_local_cuda}": str(is_local_cuda),
}
repository_ctx.template("defs.bzl", tpl_label, substitutions = substitutions, executable = False)

def _generate_toolchain_build(repository_ctx, cuda):
tpl_label = Label(
"//cuda/private:templates/BUILD.local_toolchain_" +
("nvcc" if _is_linux(repository_ctx) else "nvcc_msvc"),
)
substitutions = {
"%{cuda_path}": _to_forward_slash(cuda.path) if cuda.path else "cuda-not-found",
"%{cuda_version}": "{}.{}".format(cuda.version_major, cuda.version_minor),
"%{nvcc_version_major}": str(cuda.nvcc_version_major),
"%{nvcc_version_minor}": str(cuda.nvcc_version_minor),
"%{nvlink_label}": cuda.nvlink_label,
"%{link_stub_label}": cuda.link_stub_label,
"%{bin2c_label}": cuda.bin2c_label,
"%{fatbinary_label}": cuda.fatbinary_label,
}
env_tmp = repository_ctx.os.environ.get("TMP", repository_ctx.os.environ.get("TEMP", None))
if env_tmp != None:
substitutions["%{env_tmp}"] = _to_forward_slash(env_tmp)
repository_ctx.template("toolchain/BUILD", tpl_label, substitutions = substitutions, executable = False)

def _generate_toolchain_clang_build(repository_ctx, cuda, clang_path):
tpl_label = Label("//cuda/private:templates/BUILD.local_toolchain_clang")
substitutions = {
"%{clang_path}": _to_forward_slash(clang_path) if clang_path else "cuda-clang-not-found",
"%{cuda_path}": _to_forward_slash(cuda.path) if cuda.path else "cuda-not-found",
"%{cuda_version}": "{}.{}".format(cuda.version_major, cuda.version_minor),
"%{nvlink_label}": cuda.nvlink_label,
"%{link_stub_label}": cuda.link_stub_label,
"%{bin2c_label}": cuda.bin2c_label,
"%{fatbinary_label}": cuda.fatbinary_label,
}
repository_ctx.template("toolchain/clang/BUILD", tpl_label, substitutions = substitutions, executable = False)

template_helper = struct(
generate_build = _generate_build,
generate_defs_bzl = _generate_defs_bzl,
generate_toolchain_build = _generate_toolchain_build,
generate_toolchain_clang_build = _generate_toolchain_clang_build,
)
22 changes: 22 additions & 0 deletions cuda/private/templates/BUILD.cccl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
cc_library(
name = "cub",
hdrs = glob(
["%{component_name}/include/cub/**"],
allow_empty = True,
),
includes = [
"%{component_name}/include",
],
)

cc_library(
name = "thrust",
hdrs = glob(
["%{component_name}/include/thrust/**"],
allow_empty = True,
),
includes = [
"%{component_name}/include",
],
deps = [":cub"],
)
37 changes: 37 additions & 0 deletions cuda/private/templates/BUILD.cublas
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

cc_import_versioned_sos(
name = "cublas_so",
shared_library = "%{component_name}/%{libpath}/libcublas.so",
)

cc_import_versioned_sos(
name = "cublasLt_so",
shared_library = "%{component_name}/%{libpath}/libcublasLt.so",
)

cc_import(
name = "cublas_lib",
interface_library = "%{component_name}/%{libpath}/x64/cublas.lib",
system_provided = 1,
target_compatible_with = ["@platforms//os:windows"],
)

cc_import(
name = "cublasLt_lib",
interface_library = "%{component_name}/%{libpath}/x64/cublasLt.lib",
system_provided = 1,
target_compatible_with = ["@platforms//os:windows"],
)

cc_library(
name = "cublas",
deps = [
":headers",
] + if_linux([
":cublasLt_so",
":cublas_so",
]) + if_windows([
":cublasLt_lib",
":cublas_lib",
]),
)
95 changes: 95 additions & 0 deletions cuda/private/templates/BUILD.cudart
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
cc_import_versioned_sos(
name = "cudart_so",
shared_library = "%{component_name}/%{libpath}/libcudart.so",
)

cc_library(
name = "cudadevrt_a",
srcs = ["%{component_name}/%{libpath}/libcudadevrt.a"],
target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
name = "cudart_lib",
interface_library = "%{component_name}/%{libpath}/x64/cudart.lib",
system_provided = 1,
target_compatible_with = ["@platforms//os:windows"],
)

cc_import(
name = "cudadevrt_lib",
interface_library = "%{component_name}/%{libpath}/x64/cudadevrt.lib",
system_provided = 1,
target_compatible_with = ["@platforms//os:windows"],
)

# Note: do not use this target directly, use the configurable label_flag
# @rules_cuda//cuda:runtime instead.
cc_library(
name = "cuda_runtime",
linkopts = if_linux([
"-ldl",
"-lpthread",
"-lrt",
]),
deps = [
":headers",
] + if_linux([
# devrt is required for jit linking when rdc is enabled
":cudadevrt_a",
":cudart_so",
]) + if_windows([
# devrt is required for jit linking when rdc is enabled
":cudadevrt_lib",
":cudart_lib",
]),
# FIXME:
# visibility = ["@rules_cuda//cuda:__pkg__"],
)

# Note: do not use this target directly, use the configurable label_flag
# @rules_cuda//cuda:runtime instead.
cc_library(
name = "cuda_runtime_static",
srcs = ["%{component_name}/%{libpath}/libcudart_static.a"],
hdrs = [":_cuda_header_files"],
includes = ["%{component_name}/include"],
linkopts = if_linux([
"-ldl",
"-lpthread",
"-lrt",
]),
deps = [":cudadevrt_a"],
# FIXME:
# visibility = ["@rules_cuda//cuda:__pkg__"],
)

cc_library(
name = "no_cuda_runtime",
# FIXME:
# visibility = ["@rules_cuda//cuda:__pkg__"],
)

cc_import(
name = "cuda_so",
shared_library = "%{component_name}/%{libpath}/stubs/libcuda.so",
target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
name = "cuda_lib",
interface_library = "%{component_name}/%{libpath}/x64/cuda.lib",
system_provided = 1,
target_compatible_with = ["@platforms//os:windows"],
)

cc_library(
name = "cuda",
deps = [
":headers",
] + if_linux([
":cuda_so",
]) + if_windows([
":cuda_lib",
]),
)
36 changes: 36 additions & 0 deletions cuda/private/templates/BUILD.cufft
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
cc_import_versioned_sos(
name = "cufft_so",
shared_library = "%{component_name}/%{libpath}/libcufft.so",
)

cc_import(
name = "cufft_lib",
interface_library = "%{component_name}/%{libpath}/x64/cufft.lib",
system_provided = 1,
target_compatible_with = ["@platforms//os:windows"],
)

cc_import_versioned_sos(
name = "cufftw_so",
shared_library = "%{component_name}/%{libpath}/libcufftw.so",
)

cc_import(
name = "cufftw_lib",
interface_library = "%{component_name}/%{libpath}/x64/cufftw.lib",
system_provided = 1,
target_compatible_with = ["@platforms//os:windows"],
)

cc_library(
name = "cufft",
deps = [
":headers",
] + if_linux([
":cufft_so",
":cufftw_so",
]) + if_windows([
":cufft_lib",
":cufftw_lib",
]),
)
Empty file.
Loading

0 comments on commit b7887dd

Please sign in to comment.