Skip to content

Commit

Permalink
feat: add config serve_addr & env $SERVE_ADDR for specifying serve …
Browse files Browse the repository at this point in the history
…addr (#839)
  • Loading branch information
sigoden authored Sep 5, 2024
1 parent b16913f commit 791b615
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
14 changes: 14 additions & 0 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ pub const TEMP_SESSION_NAME: &str = "temp";

const CLIENTS_FIELD: &str = "clients";

const SERVE_ADDR: &str = "127.0.0.1:8000";

const SUMMARIZE_PROMPT: &str =
"Summarize the discussion briefly in 200 words or less to use as a prompt for future context.";
const SUMMARY_PROMPT: &str = "This is a summary of the chat history as a recap: ";
Expand Down Expand Up @@ -126,6 +128,8 @@ pub struct Config {
pub left_prompt: Option<String>,
pub right_prompt: Option<String>,

pub serve_addr: Option<String>,

pub clients: Vec<ClientConfig>,

#[serde(skip)]
Expand Down Expand Up @@ -191,6 +195,8 @@ impl Default for Config {
left_prompt: None,
right_prompt: None,

serve_addr: None,

clients: vec![],

role: None,
Expand Down Expand Up @@ -392,6 +398,10 @@ impl Config {
flags
}

pub fn serve_addr(&self) -> String {
self.serve_addr.clone().unwrap_or_else(|| SERVE_ADDR.into())
}

pub fn log(is_serve: bool) -> Result<(LevelFilter, Option<PathBuf>)> {
let log_level = env::var(get_env_name("log_level"))
.ok()
Expand Down Expand Up @@ -1848,6 +1858,10 @@ impl Config {
if let Some(v) = read_env_value::<String>("right_prompt") {
self.right_prompt = v;
}

if let Some(v) = read_env_value::<String>("serve_addr") {
self.serve_addr = v;
}
}

fn load_functions(&mut self) -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ impl Repl {
None => println!("Usage: .prompt <text>..."),
},
".role" => match args {
Some(args) => match args.split_once(|c| c == '\n' || c == ' ') {
Some(args) => match args.split_once(['\n', ' ']) {
Some((name, text)) => {
let role = self.config.read().retrieve_role(name.trim())?;
let input = Input::from_str(&self.config, text.trim(), Some(role));
Expand Down
3 changes: 1 addition & 2 deletions src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use tokio::{
use tokio_graceful::Shutdown;
use tokio_stream::wrappers::UnboundedReceiverStream;

const DEFAULT_ADDRESS: &str = "127.0.0.1:8000";
const DEFAULT_MODEL_NAME: &str = "default";
const PLAYGROUND_HTML: &[u8] = include_bytes!("../assets/playground.html");
const ARENA_HTML: &[u8] = include_bytes!("../assets/arena.html");
Expand All @@ -50,7 +49,7 @@ pub async fn run(config: GlobalConfig, addr: Option<String>) -> Result<()> {
addr
}
}
None => DEFAULT_ADDRESS.to_string(),
None => config.read().serve_addr(),
};
let server = Arc::new(Server::new(&config));
let listener = TcpListener::bind(&addr).await?;
Expand Down

0 comments on commit 791b615

Please sign in to comment.