Skip to content

Commit

Permalink
add system prompt for Gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
jelni committed Aug 27, 2024
1 parent 2042911 commit 189b089
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 18 deletions.
13 changes: 9 additions & 4 deletions src/apis/makersuite.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::time::Duration;
use std::{env, fmt};

Expand Down Expand Up @@ -122,18 +123,19 @@ pub async fn upload_file(
struct GenerateContentRequest<'a> {
contents: &'a [Content<'a>],
safety_settings: &'static [SafetySetting],
system_instruction: Option<Content<'a>>,
generation_config: GenerationConfig,
}

#[derive(Serialize)]
struct Content<'a> {
parts: &'a [Part],
parts: &'a [Part<'a>],
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Part {
Text(String),
pub enum Part<'a> {
Text(Cow<'a, str>),
FileData(FileData),
}

Expand Down Expand Up @@ -225,7 +227,8 @@ pub async fn stream_generate_content(
http_client: reqwest::Client,
tx: mpsc::UnboundedSender<Result<GenerateContentResponse, GenerationError>>,
model: &str,
parts: &[Part],
parts: &[Part<'_>],
system_instruction: Option<&[Part<'_>]>,
max_output_tokens: u16,
) {
let url = format!(
Expand All @@ -251,6 +254,8 @@ pub async fn stream_generate_content(
},
SafetySetting { category: "HARM_CATEGORY_HARASSMENT", threshold: "BLOCK_NONE" },
],
system_instruction: system_instruction
.map(|system_instruction| Content { parts: system_instruction }),
generation_config: GenerationConfig { max_output_tokens },
})
.send()
Expand Down
51 changes: 37 additions & 14 deletions src/commands/makersuite.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::fmt::Write;
use std::time::Duration;

Expand All @@ -18,6 +19,9 @@ use crate::utilities::file_download::MEBIBYTE;
use crate::utilities::rate_limit::RateLimiter;
use crate::utilities::telegram_utils;

const SYSTEM_INSTRUCTION: &str =
"Be concise and precise. Don't be verbose. Answer in the user's language.";

pub struct GoogleGemini;

#[async_trait]
Expand All @@ -34,18 +38,12 @@ impl CommandTrait for GoogleGemini {
RateLimiter::new(3, 45)
}

#[allow(clippy::too_many_lines)]
async fn execute(&self, ctx: &CommandContext, arguments: String) -> CommandResult {
let prompt = Option::<StringGreedyOrReply>::convert(ctx, &arguments).await?.0;

ctx.send_typing().await?;
let mut model = "gemini-1.0-pro-latest";
let mut parts = Vec::new();

if let Some(prompt) = prompt {
parts.push(Part::Text(prompt.0));
}

if let Some(message_image) =
let (model, system_instruction, parts) = if let Some(message_image) =
telegram_utils::get_message_or_reply_attachment(&ctx.message, true, ctx.client_id)
.await?
{
Expand All @@ -68,19 +66,44 @@ impl CommandTrait for GoogleGemini {
)
.await?;

let mut parts = if let Some(prompt) = prompt {
vec![Part::Text(Cow::Owned(prompt.0))]
} else {
Vec::new()
};

parts.push(Part::FileData(FileData { file_uri: file.uri }));
model = "gemini-1.5-flash-latest";
}

if parts.is_empty() {
return Err(CommandError::Custom("no prompt or file provided.".into()));
}
(
"gemini-1.5-flash-latest",
Some([Part::Text(Cow::Borrowed(SYSTEM_INSTRUCTION))].as_slice()),
parts,
)
} else {
let mut parts = vec![Part::Text(Cow::Borrowed(SYSTEM_INSTRUCTION))];

if let Some(prompt) = prompt {
parts.push(Part::Text(Cow::Owned(prompt.0)));
} else {
return Err(CommandError::Custom("no prompt or file provided.".into()));
}

("gemini-1.0-pro-latest", None, parts)
};

let http_client = ctx.bot_state.http_client.clone();
let (tx, mut rx) = mpsc::unbounded_channel();

tokio::spawn(async move {
makersuite::stream_generate_content(http_client, tx, model, &parts, 512).await;
makersuite::stream_generate_content(
http_client,
tx,
model,
&parts,
system_instruction,
512,
)
.await;
});

let mut next_update = Instant::now() + Duration::from_secs(5);
Expand Down

0 comments on commit 189b089

Please sign in to comment.