-
-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: split monolithic BUILD.local_cuda into template fragments
- Loading branch information
Showing
32 changed files
with
703 additions
and
613 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
load("//cuda/private:templates/registry.bzl", "REGISTRY") | ||
|
||
def _to_forward_slash(s): | ||
return s.replace("\\", "/") | ||
|
||
def _is_linux(ctx): | ||
return ctx.os.name.startswith("linux") | ||
|
||
def _is_windows(ctx): | ||
return ctx.os.name.lower().startswith("windows") | ||
|
||
def _generate_build(repository_ctx, libpath): | ||
# stitch template fragment | ||
fragments = [ | ||
Label("//cuda/private:templates/BUILD.local_cuda_shared"), | ||
Label("//cuda/private:templates/BUILD.local_cuda_headers"), | ||
Label("//cuda/private:templates/BUILD.local_cuda_build_setting"), | ||
] | ||
fragments.extend([Label("//cuda/private:templates/BUILD.{}".format(c)) for c in REGISTRY if len(REGISTRY[c]) > 0]) | ||
|
||
template_content = [] | ||
for frag in fragments: | ||
template_content.append("# Generated from fragment " + str(frag)) | ||
template_content.append(repository_ctx.read(frag)) | ||
|
||
template_content = "\n".join(template_content) | ||
|
||
template_path = repository_ctx.path("BUILD.tpl") | ||
repository_ctx.file(template_path, content = template_content, executable = False) | ||
|
||
substitutions = { | ||
"%{component_name}": "cuda", | ||
"%{libpath}": libpath, | ||
} | ||
repository_ctx.template("BUILD", template_path, substitutions = substitutions, executable = False) | ||
|
||
def _generate_defs_bzl(repository_ctx, is_local_cuda): | ||
tpl_label = Label("//cuda/private:templates/defs.bzl.tpl") | ||
substitutions = { | ||
"%{is_local_cuda}": str(is_local_cuda), | ||
} | ||
repository_ctx.template("defs.bzl", tpl_label, substitutions = substitutions, executable = False) | ||
|
||
def _generate_toolchain_build(repository_ctx, cuda): | ||
tpl_label = Label( | ||
"//cuda/private:templates/BUILD.local_toolchain_" + | ||
("nvcc" if _is_linux(repository_ctx) else "nvcc_msvc"), | ||
) | ||
substitutions = { | ||
"%{cuda_path}": _to_forward_slash(cuda.path) if cuda.path else "cuda-not-found", | ||
"%{cuda_version}": "{}.{}".format(cuda.version_major, cuda.version_minor), | ||
"%{nvcc_version_major}": str(cuda.nvcc_version_major), | ||
"%{nvcc_version_minor}": str(cuda.nvcc_version_minor), | ||
"%{nvlink_label}": cuda.nvlink_label, | ||
"%{link_stub_label}": cuda.link_stub_label, | ||
"%{bin2c_label}": cuda.bin2c_label, | ||
"%{fatbinary_label}": cuda.fatbinary_label, | ||
} | ||
env_tmp = repository_ctx.os.environ.get("TMP", repository_ctx.os.environ.get("TEMP", None)) | ||
if env_tmp != None: | ||
substitutions["%{env_tmp}"] = _to_forward_slash(env_tmp) | ||
repository_ctx.template("toolchain/BUILD", tpl_label, substitutions = substitutions, executable = False) | ||
|
||
def _generate_toolchain_clang_build(repository_ctx, cuda, clang_path): | ||
tpl_label = Label("//cuda/private:templates/BUILD.local_toolchain_clang") | ||
substitutions = { | ||
"%{clang_path}": _to_forward_slash(clang_path) if clang_path else "cuda-clang-not-found", | ||
"%{cuda_path}": _to_forward_slash(cuda.path) if cuda.path else "cuda-not-found", | ||
"%{cuda_version}": "{}.{}".format(cuda.version_major, cuda.version_minor), | ||
"%{nvlink_label}": cuda.nvlink_label, | ||
"%{link_stub_label}": cuda.link_stub_label, | ||
"%{bin2c_label}": cuda.bin2c_label, | ||
"%{fatbinary_label}": cuda.fatbinary_label, | ||
} | ||
repository_ctx.template("toolchain/clang/BUILD", tpl_label, substitutions = substitutions, executable = False) | ||
|
||
template_helper = struct( | ||
generate_build = _generate_build, | ||
generate_defs_bzl = _generate_defs_bzl, | ||
generate_toolchain_build = _generate_toolchain_build, | ||
generate_toolchain_clang_build = _generate_toolchain_clang_build, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
cc_library( | ||
name = "cub", | ||
hdrs = glob( | ||
["%{component_name}/include/cub/**"], | ||
allow_empty = True, | ||
), | ||
includes = [ | ||
"%{component_name}/include", | ||
], | ||
) | ||
|
||
cc_library( | ||
name = "thrust", | ||
hdrs = glob( | ||
["%{component_name}/include/thrust/**"], | ||
allow_empty = True, | ||
), | ||
includes = [ | ||
"%{component_name}/include", | ||
], | ||
deps = [":cub"], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
|
||
cc_import_versioned_sos( | ||
name = "cublas_so", | ||
shared_library = "%{component_name}/%{libpath}/libcublas.so", | ||
) | ||
|
||
cc_import_versioned_sos( | ||
name = "cublasLt_so", | ||
shared_library = "%{component_name}/%{libpath}/libcublasLt.so", | ||
) | ||
|
||
cc_import( | ||
name = "cublas_lib", | ||
interface_library = "%{component_name}/%{libpath}/x64/cublas.lib", | ||
system_provided = 1, | ||
target_compatible_with = ["@platforms//os:windows"], | ||
) | ||
|
||
cc_import( | ||
name = "cublasLt_lib", | ||
interface_library = "%{component_name}/%{libpath}/x64/cublasLt.lib", | ||
system_provided = 1, | ||
target_compatible_with = ["@platforms//os:windows"], | ||
) | ||
|
||
cc_library( | ||
name = "cublas", | ||
deps = [ | ||
":headers", | ||
] + if_linux([ | ||
":cublasLt_so", | ||
":cublas_so", | ||
]) + if_windows([ | ||
":cublasLt_lib", | ||
":cublas_lib", | ||
]), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
cc_import_versioned_sos( | ||
name = "cudart_so", | ||
shared_library = "%{component_name}/%{libpath}/libcudart.so", | ||
) | ||
|
||
cc_library( | ||
name = "cudadevrt_a", | ||
srcs = ["%{component_name}/%{libpath}/libcudadevrt.a"], | ||
target_compatible_with = ["@platforms//os:linux"], | ||
) | ||
|
||
cc_import( | ||
name = "cudart_lib", | ||
interface_library = "%{component_name}/%{libpath}/x64/cudart.lib", | ||
system_provided = 1, | ||
target_compatible_with = ["@platforms//os:windows"], | ||
) | ||
|
||
cc_import( | ||
name = "cudadevrt_lib", | ||
interface_library = "%{component_name}/%{libpath}/x64/cudadevrt.lib", | ||
system_provided = 1, | ||
target_compatible_with = ["@platforms//os:windows"], | ||
) | ||
|
||
# Note: do not use this target directly, use the configurable label_flag | ||
# @rules_cuda//cuda:runtime instead. | ||
cc_library( | ||
name = "cuda_runtime", | ||
linkopts = if_linux([ | ||
"-ldl", | ||
"-lpthread", | ||
"-lrt", | ||
]), | ||
deps = [ | ||
":headers", | ||
] + if_linux([ | ||
# devrt is required for jit linking when rdc is enabled | ||
":cudadevrt_a", | ||
":cudart_so", | ||
]) + if_windows([ | ||
# devrt is required for jit linking when rdc is enabled | ||
":cudadevrt_lib", | ||
":cudart_lib", | ||
]), | ||
# FIXME: | ||
# visibility = ["@rules_cuda//cuda:__pkg__"], | ||
) | ||
|
||
# Note: do not use this target directly, use the configurable label_flag | ||
# @rules_cuda//cuda:runtime instead. | ||
cc_library( | ||
name = "cuda_runtime_static", | ||
srcs = ["%{component_name}/%{libpath}/libcudart_static.a"], | ||
hdrs = [":_cuda_header_files"], | ||
includes = ["%{component_name}/include"], | ||
linkopts = if_linux([ | ||
"-ldl", | ||
"-lpthread", | ||
"-lrt", | ||
]), | ||
deps = [":cudadevrt_a"], | ||
# FIXME: | ||
# visibility = ["@rules_cuda//cuda:__pkg__"], | ||
) | ||
|
||
cc_library( | ||
name = "no_cuda_runtime", | ||
# FIXME: | ||
# visibility = ["@rules_cuda//cuda:__pkg__"], | ||
) | ||
|
||
cc_import( | ||
name = "cuda_so", | ||
shared_library = "%{component_name}/%{libpath}/stubs/libcuda.so", | ||
target_compatible_with = ["@platforms//os:linux"], | ||
) | ||
|
||
cc_import( | ||
name = "cuda_lib", | ||
interface_library = "%{component_name}/%{libpath}/x64/cuda.lib", | ||
system_provided = 1, | ||
target_compatible_with = ["@platforms//os:windows"], | ||
) | ||
|
||
cc_library( | ||
name = "cuda", | ||
deps = [ | ||
":headers", | ||
] + if_linux([ | ||
":cuda_so", | ||
]) + if_windows([ | ||
":cuda_lib", | ||
]), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
cc_import_versioned_sos( | ||
name = "cufft_so", | ||
shared_library = "%{component_name}/%{libpath}/libcufft.so", | ||
) | ||
|
||
cc_import( | ||
name = "cufft_lib", | ||
interface_library = "%{component_name}/%{libpath}/x64/cufft.lib", | ||
system_provided = 1, | ||
target_compatible_with = ["@platforms//os:windows"], | ||
) | ||
|
||
cc_import_versioned_sos( | ||
name = "cufftw_so", | ||
shared_library = "%{component_name}/%{libpath}/libcufftw.so", | ||
) | ||
|
||
cc_import( | ||
name = "cufftw_lib", | ||
interface_library = "%{component_name}/%{libpath}/x64/cufftw.lib", | ||
system_provided = 1, | ||
target_compatible_with = ["@platforms//os:windows"], | ||
) | ||
|
||
cc_library( | ||
name = "cufft", | ||
deps = [ | ||
":headers", | ||
] + if_linux([ | ||
":cufft_so", | ||
":cufftw_so", | ||
]) + if_windows([ | ||
":cufft_lib", | ||
":cufftw_lib", | ||
]), | ||
) |
Empty file.
Oops, something went wrong.