Skip to content

Commit

Permalink
feat(cli): download output to dir
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobLinCool committed Jul 25, 2024
1 parent dedba59 commit 2d00836
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions src/bin/gr.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::path::PathBuf;

use anyhow::Result;
use clap::{arg, Command};
use gradio::{Client, ClientOptions, PredictionInput};
Expand All @@ -7,6 +9,7 @@ async fn main() -> Result<()> {
let matches = cli().get_matches();

let token = matches.get_one::<String>("token");
let output = matches.get_one::<String>("output");

match matches.subcommand() {
Some(("run", sub_matches)) => {
Expand All @@ -16,7 +19,7 @@ async fn main() -> Result<()> {
.get_many::<String>("options")
.unwrap_or_default()
.collect();
run_command(space_id, route, options, token).await?;
run_command(space_id, route, options, token, output).await?;
}
Some(("list", sub_matches)) => {
let space_id = sub_matches.get_one::<String>("space_id").expect("required");
Expand All @@ -34,7 +37,8 @@ fn cli() -> Command {
Command::new("gr")
.version(env!("CARGO_PKG_VERSION"))
.about("Gradio Command Line Client")
.arg(arg!(--token <token> "The Hugging Face Access Token"))
.arg(arg!(-t --token <token> "The Hugging Face Access Token"))
.arg(arg!(-o --output <output> "Output directory, if specified, files will be saved to this directory"))
.subcommand_required(true)
.arg_required_else_help(true)
.subcommand(
Expand All @@ -58,6 +62,7 @@ async fn run_command(
route: &str,
options: Vec<&String>,
token: Option<&String>,
outdir: Option<&String>,
) -> Result<()> {
let route = format!("/{}", route.trim_start_matches('/'));

Expand Down Expand Up @@ -105,6 +110,7 @@ async fn run_command(
}
}

let http_client = client.http_client.clone();
let output = client.predict(&route, data).await?;
for (i, ret) in endpoint.returns.iter().enumerate() {
let value = output.get(i).expect("Missing return value");
Expand All @@ -117,11 +123,15 @@ async fn run_command(
};

if value.is_file() {
println!(
"{}: {}",
name,
value.clone().as_file()?.url.unwrap_or("".to_string())
);
let file = value.clone().as_file()?;
if let Some(outdir) = outdir {
let mut fp = PathBuf::from(outdir);
fp.push(format!("{}.{}", name, file.suggest_extension()));
file.save_to_path(&fp, Some(http_client.clone())).await?;
println!("{}: {}", name, fp.display());
} else {
println!("{}: {}", name, file.url.unwrap_or("".to_string()));
}
} else {
println!("{}: {}", name, value.clone().as_value()?);
}
Expand Down

0 comments on commit 2d00836

Please sign in to comment.