Skip to content

Commit

Permalink
add refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelfeil committed Jan 19, 2025
1 parent 0da839e commit 249ccb9
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions truss-transfer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ static CACHE_DIR: &str = "/cache/org/artifacts";
static BLOB_DOWNLOAD_TIMEOUT_SECS: u64 = 7200;
static BASETEN_FS_ENABLED_ENV_VAR: &str = "BASETEN_FS_ENABLED";
static TRUSS_TRANSFER_NUM_WORKERS_DEFAULT: usize = 64;
static TRUSS_TRANSFER_DOWNLOAD_DIR_ENV_VAR: &str = "TRUSS_TRANSFER_DOWNLOAD_DIR";

// Global lock to serialize downloads
static GLOBAL_DOWNLOAD_LOCK: OnceLock<Arc<Mutex<()>>> = OnceLock::new();
Expand All @@ -31,6 +32,21 @@ fn get_global_lock() -> &'static Arc<Mutex<()>> {
GLOBAL_DOWNLOAD_LOCK.get_or_init(|| Arc::new(Mutex::new(())))
}

fn resolve_truss_transfer_download_dir(optional_download_dir: Option<String>) -> String {
// Order:
// 1. optional_download_dir, if provided
// 2. TRUSS_TRANSFER_DOWNLOAD_DIR_ENV_VAR
// else: raise error
optional_download_dir
.or_else(|| env::var(TRUSS_TRANSFER_DOWNLOAD_DIR_ENV_VAR).ok())
.unwrap_or_else(|| {
panic!(
"No download directory provided. Please set `export {}=/path/to/dir` or pass it as an argument.",
TRUSS_TRANSFER_DOWNLOAD_DIR_ENV_VAR
)
})
}

/// Corresponds to `Resolution` in the Python code
#[derive(Debug, Deserialize)]
struct Resolution {
Expand All @@ -57,17 +73,19 @@ struct BasetenPointerManifest {
}

/// Python-callable function to read the manifest and download data.
/// By default, uses 64 concurrent workers if you don't specify `num_workers`.
/// By default, it will use the `TRUSS_TRANSFER_DOWNLOAD_DIR` environment variable.
#[pyfunction]
#[pyo3(signature = (download_dir))]
fn lazy_data_resolve(download_dir: String) -> PyResult<()> {
fn lazy_data_resolve(download_dir: Option<String>) -> PyResult<()> {
lazy_data_resolve_entrypoint(download_dir).map_err(|err| PyException::new_err(err.to_string()))
}

/// Shared entrypoint for both Python and CLI
fn lazy_data_resolve_entrypoint(download_dir: String) -> Result<()> {
fn lazy_data_resolve_entrypoint(download_dir: Option<String>) -> Result<()> {
let num_workers = TRUSS_TRANSFER_NUM_WORKERS_DEFAULT;

let download_dir = resolve_truss_transfer_download_dir(download_dir);

// Ensure the global lock is initialized
let lock = get_global_lock();

Expand Down Expand Up @@ -335,16 +353,9 @@ fn main() -> anyhow::Result<()> {
"[INFO] truss_transfer_cli, version: {}",
env!("CARGO_PKG_VERSION")
);
let args: Vec<String> = std::env::args().collect();

if args.len() < 2 {
println!("Usage: {} <download_dir>", args[0]);
return Ok(());
}

let download_dir = &args[1];

println!("[INFO] Invoking lazy_data_resolve_async with download_dir='{download_dir}'");
let download_dir = std::env::args().nth(1);

lazy_data_resolve_entrypoint(download_dir.into())
}
Expand Down

0 comments on commit 249ccb9

Please sign in to comment.