From 2d00836933a85f2a955d6fae765e8596a967c6d8 Mon Sep 17 00:00:00 2001 From: JacobLinCool Date: Fri, 26 Jul 2024 04:31:13 +0800 Subject: [PATCH] feat(cli): download output to dir --- src/bin/gr.rs | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/bin/gr.rs b/src/bin/gr.rs index 18213ab..dd2e7c7 100644 --- a/src/bin/gr.rs +++ b/src/bin/gr.rs @@ -1,3 +1,5 @@ +use std::path::PathBuf; + use anyhow::Result; use clap::{arg, Command}; use gradio::{Client, ClientOptions, PredictionInput}; @@ -7,6 +9,7 @@ async fn main() -> Result<()> { let matches = cli().get_matches(); let token = matches.get_one::("token"); + let output = matches.get_one::("output"); match matches.subcommand() { Some(("run", sub_matches)) => { @@ -16,7 +19,7 @@ async fn main() -> Result<()> { .get_many::("options") .unwrap_or_default() .collect(); - run_command(space_id, route, options, token).await?; + run_command(space_id, route, options, token, output).await?; } Some(("list", sub_matches)) => { let space_id = sub_matches.get_one::("space_id").expect("required"); @@ -34,7 +37,8 @@ fn cli() -> Command { Command::new("gr") .version(env!("CARGO_PKG_VERSION")) .about("Gradio Command Line Client") - .arg(arg!(--token "The Hugging Face Access Token")) + .arg(arg!(-t --token "The Hugging Face Access Token")) + .arg(arg!(-o --output "Output directory, if specified, files will be saved to this directory")) .subcommand_required(true) .arg_required_else_help(true) .subcommand( @@ -58,6 +62,7 @@ async fn run_command( route: &str, options: Vec<&String>, token: Option<&String>, + outdir: Option<&String>, ) -> Result<()> { let route = format!("/{}", route.trim_start_matches('/')); @@ -105,6 +110,7 @@ async fn run_command( } } + let http_client = client.http_client.clone(); let output = client.predict(&route, data).await?; for (i, ret) in endpoint.returns.iter().enumerate() { let value = output.get(i).expect("Missing return value"); @@ -117,11 +123,15 @@ async fn run_command( }; if value.is_file() { - println!( - "{}: {}", - name, - value.clone().as_file()?.url.unwrap_or("".to_string()) - ); + let file = value.clone().as_file()?; + if let Some(outdir) = outdir { + let mut fp = PathBuf::from(outdir); + fp.push(format!("{}.{}", name, file.suggest_extension())); + file.save_to_path(&fp, Some(http_client.clone())).await?; + println!("{}: {}", name, fp.display()); + } else { + println!("{}: {}", name, file.url.unwrap_or("".to_string())); + } } else { println!("{}: {}", name, value.clone().as_value()?); }