Skip to content

Commit

Permalink
Improve handling of compile_data with mixed sources (#3176)
Browse files Browse the repository at this point in the history
I'm not very well versed in starlark nor the rules_rust codebase, so
feel free to ignore this and address the issue in a more fitting way,
but this fixes #3171 and a related issue for me.
  • Loading branch information
martingms authored Jan 13, 2025
1 parent 5e426fa commit bb74a65
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 21 deletions.
26 changes: 12 additions & 14 deletions rust/private/rust.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _rust_library_common(ctx, crate_type):
crate_root = getattr(ctx.file, "crate_root", None)
if not crate_root:
crate_root = crate_root_src(ctx.attr.name, ctx.attr.crate_name, ctx.files.srcs, crate_type)
srcs, crate_root = transform_sources(ctx, ctx.files.srcs, crate_root)
srcs, compile_data, crate_root = transform_sources(ctx, ctx.files.srcs, ctx.files.compile_data, crate_root)

# Determine unique hash for this rlib.
# Note that we don't include a hash for `cdylib` and `staticlib` since they are meant to be consumed externally
Expand Down Expand Up @@ -202,7 +202,7 @@ def _rust_library_common(ctx, crate_type):
rustc_env_files = ctx.files.rustc_env_files,
is_test = False,
data = depset(ctx.files.data),
compile_data = depset(ctx.files.compile_data),
compile_data = depset(compile_data),
compile_data_targets = depset(ctx.attr.compile_data),
owner = ctx.label,
),
Expand Down Expand Up @@ -233,7 +233,7 @@ def _rust_binary_impl(ctx):
crate_root = getattr(ctx.file, "crate_root", None)
if not crate_root:
crate_root = crate_root_src(ctx.attr.name, ctx.attr.crate_name, ctx.files.srcs, ctx.attr.crate_type)
srcs, crate_root = transform_sources(ctx, ctx.files.srcs, crate_root)
srcs, compile_data, crate_root = transform_sources(ctx, ctx.files.srcs, ctx.files.compile_data, crate_root)

providers = rustc_compile_action(
ctx = ctx,
Expand All @@ -254,7 +254,7 @@ def _rust_binary_impl(ctx):
rustc_env_files = ctx.files.rustc_env_files,
is_test = False,
data = depset(ctx.files.data),
compile_data = depset(ctx.files.compile_data),
compile_data = depset(compile_data),
compile_data_targets = depset(ctx.attr.compile_data),
owner = ctx.label,
),
Expand Down Expand Up @@ -330,13 +330,11 @@ def _rust_test_impl(ctx):
),
)

srcs, crate_root = transform_sources(ctx, ctx.files.srcs, getattr(ctx.file, "crate_root", None))
# Need to consider all src files together when transforming
srcs = depset(ctx.files.srcs, transitive = [crate.srcs]).to_list()
compile_data = depset(ctx.files.compile_data, transitive = [crate.compile_data]).to_list()
srcs, compile_data, crate_root = transform_sources(ctx, srcs, compile_data, getattr(ctx.file, "crate_root", None))

# Optionally join compile data
if crate.compile_data:
compile_data = depset(ctx.files.compile_data, transitive = [crate.compile_data])
else:
compile_data = depset(ctx.files.compile_data)
if crate.compile_data_targets:
compile_data_targets = depset(ctx.attr.compile_data, transitive = [crate.compile_data_targets])
else:
Expand All @@ -360,7 +358,7 @@ def _rust_test_impl(ctx):
name = crate_name,
type = crate_type,
root = crate.root,
srcs = depset(srcs, transitive = [crate.srcs]),
srcs = depset(srcs),
deps = depset(deps, transitive = [crate.deps]),
proc_macro_deps = depset(proc_macro_deps, transitive = [crate.proc_macro_deps]),
aliases = aliases,
Expand All @@ -370,7 +368,7 @@ def _rust_test_impl(ctx):
rustc_env = rustc_env,
rustc_env_files = rustc_env_files,
is_test = True,
compile_data = compile_data,
compile_data = depset(compile_data),
compile_data_targets = compile_data_targets,
wrapped_crate_type = crate.type,
owner = ctx.label,
Expand All @@ -381,7 +379,7 @@ def _rust_test_impl(ctx):
if not crate_root:
crate_root_type = "lib" if ctx.attr.use_libtest_harness else "bin"
crate_root = crate_root_src(ctx.attr.name, ctx.attr.crate_name, ctx.files.srcs, crate_root_type)
srcs, crate_root = transform_sources(ctx, ctx.files.srcs, crate_root)
srcs, compile_data, crate_root = transform_sources(ctx, ctx.files.srcs, ctx.files.compile_data, crate_root)

if toolchain._incompatible_change_rust_test_compilation_output_directory:
output = ctx.actions.declare_file(
Expand Down Expand Up @@ -420,7 +418,7 @@ def _rust_test_impl(ctx):
rustc_env = rustc_env,
rustc_env_files = ctx.files.rustc_env_files,
is_test = True,
compile_data = depset(ctx.files.compile_data),
compile_data = depset(compile_data),
compile_data_targets = depset(ctx.attr.compile_data),
owner = ctx.label,
)
Expand Down
18 changes: 13 additions & 5 deletions rust/private/utils.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ def determine_lib_name(name, crate_type, toolchain, lib_hash = None):
extension = extension,
)

def transform_sources(ctx, srcs, crate_root):
def transform_sources(ctx, srcs, compile_data, crate_root):
"""Creates symlinks of the source files if needed.
Rustc assumes that the source files are located next to the crate root.
Expand All @@ -802,25 +802,33 @@ def transform_sources(ctx, srcs, crate_root):
Args:
ctx (struct): The current rule's context.
srcs (List[File]): The sources listed in the `srcs` attribute
compile_data (List[File]): The sources listed in the `compile_data`
attribute
crate_root (File): The file specified in the `crate_root` attribute,
if it exists, otherwise None
Returns:
Tuple(List[File], File): The transformed srcs and crate_root
Tuple(List[File], List[File], File): The transformed srcs, compile_data
and crate_root
"""
has_generated_sources = len([src for src in srcs if not src.is_source]) > 0
has_generated_sources = (
len([src for src in srcs if not src.is_source]) +
len([src for src in compile_data if not src.is_source]) >
0
)

if not has_generated_sources:
return srcs, crate_root
return srcs, compile_data, crate_root

package_root = paths.join(ctx.label.workspace_root, ctx.label.package)
generated_sources = [_symlink_for_non_generated_source(ctx, src, package_root) for src in srcs if src != crate_root]
generated_compile_data = [_symlink_for_non_generated_source(ctx, src, package_root) for src in compile_data]
generated_root = crate_root
if crate_root:
generated_root = _symlink_for_non_generated_source(ctx, crate_root, package_root)
generated_sources.append(generated_root)

return generated_sources, generated_root
return generated_sources, generated_compile_data, generated_root

def get_edition(attr, toolchain, label):
"""Returns the Rust edition from either the current rule's attributes or the current `rust_toolchain`
Expand Down
12 changes: 12 additions & 0 deletions test/unit/compile_data/compile_data_gen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/// Data loaded from generated compile data
pub const COMPILE_DATA: &str = include_str!("generated.txt");

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_compile_data_contents() {
assert_eq!(COMPILE_DATA.trim_end(), "generated compile data contents");
}
}
19 changes: 19 additions & 0 deletions test/unit/compile_data/compile_data_gen_srcs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
pub mod generated;

/// Data loaded from compile data
pub const COMPILE_DATA: &str = include_str!("compile_data.txt");

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_compile_data_contents() {
assert_eq!(COMPILE_DATA.trim_end(), "compile data contents");
}

#[test]
fn test_generated_src() {
assert_eq!(generated::GENERATED, "generated");
}
}
36 changes: 36 additions & 0 deletions test/unit/compile_data/compile_data_test.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,42 @@ def _define_test_targets():
crate = ":compile_data_env",
)

native.genrule(
name = "generated_compile_data",
outs = ["generated.txt"],
cmd = "echo 'generated compile data contents' > $@",
)

rust_library(
name = "compile_data_gen",
srcs = ["compile_data_gen.rs"],
compile_data = [":generated.txt"],
edition = "2021",
)

rust_test(
name = "compile_data_gen_unit_test",
crate = ":compile_data_gen",
)

native.genrule(
name = "generated_src",
outs = ["generated.rs"],
cmd = """echo 'pub const GENERATED: &str = "generated";' > $@""",
)

rust_library(
name = "compile_data_gen_srcs",
srcs = ["compile_data_gen_srcs.rs", ":generated.rs"],
compile_data = ["compile_data.txt"],
edition = "2021",
)

rust_test(
name = "compile_data_gen_srcs_unit_test",
crate = ":compile_data_gen_srcs",
)

def compile_data_test_suite(name):
"""Entry-point macro called from the BUILD file.
Expand Down
4 changes: 2 additions & 2 deletions test/unit/location_expansion/location_expansion_test.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ load("//test/unit:common.bzl", "assert_action_mnemonic", "assert_argv_contains")
def _location_expansion_rustc_flags_test(ctx):
env = analysistest.begin(ctx)
tut = analysistest.target_under_test(env)
action = tut.actions[0]
action = tut.actions[1]
assert_action_mnemonic(env, action, "Rustc")
assert_argv_contains(env, action, "test/unit/location_expansion/mylibrary.rs")
assert_argv_contains(env, action, ctx.bin_dir.path + "/test/unit/location_expansion/mylibrary.rs")
expected = "@${pwd}/" + ctx.bin_dir.path + "/test/unit/location_expansion/generated_flag.data"
assert_argv_contains(env, action, expected)
return analysistest.end(env)
Expand Down

0 comments on commit bb74a65

Please sign in to comment.