Skip to content

Commit

Permalink
feat: 内部で使う画像生成系の AI の API の wrapper を実装する (#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
nanai10a authored Jan 2, 2024
1 parent 47468aa commit 5775763
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ envy = { version = "0.4" }
typed-builder = { version = "0.18" }
reqwest = { version = "0.11" }
serde_json = { version = "1.0" }
base64 = { version = "0.21" }

[dependencies.serenity]
version = "0.12"
Expand Down
116 changes: 116 additions & 0 deletions src/dalle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// ref: https://platform.openai.com/docs/api-reference/images
pub struct OpenAi<Model> {
http: reqwest::Client,
token: String,
model: core::marker::PhantomData<Model>,
}

trait Model {
const NAME: &'static str;
}

macro_rules! define_model {
($vis:vis $name:ident : $model:expr) => {
$vis struct $name;

impl Model for $name {
const NAME: &'static str = $model;
}
};
}

define_model!(pub DallE2: "dall-e-2");
define_model!(pub DallE3: "dall-e-3");

impl<Model> OpenAi<Model> {
pub fn new(token: impl AsRef<str>) -> Self {
Self {
http: reqwest::Client::new(),
token: token.as_ref().to_owned(),
model: core::marker::PhantomData,
}
}
}

impl<Model: self::Model + Send + Sync> super::Image for OpenAi<Model> {
async fn create(
&self,
prompt: impl AsRef<str> + Send + Sync,
) -> anyhow::Result<super::GeneratedImage> {
let req = Request {
prompt: prompt.as_ref(),
model: Model::NAME,
response_format: "b64_json",
};

let res = self
.http
.post("https://api.openai.com/v1/images/generations")
.bearer_auth(&self.token)
.header(reqwest::header::CONTENT_TYPE, "application/json")
.body(serde_json::to_vec(&req)?)
.send()
.await?
.error_for_status()?;

if res.status() != reqwest::StatusCode::OK {
anyhow::bail!("unexpected status code: {}", res.status());
}

let res = res.bytes().await?;
let res = serde_json::from_slice::<Response>(&res)?;

let [image] = res.data;
assert!(image.b64_json.is_png());

Ok(super::GeneratedImage {
image: image.b64_json.0,
prompt: image.revised_prompt.unwrap_or(prompt.as_ref()).to_owned(),
ext: super::ImageExt::Png,
})
}
}

#[derive(serde::Serialize)]
struct Request<'a> {
prompt: &'a str,
model: &'a str,
response_format: &'a str,
}

#[derive(serde::Deserialize)]
struct Response<'a> {
#[serde(borrow)]
data: [Image<'a>; 1],
}

#[derive(serde::Deserialize)]
struct Image<'a> {
b64_json: Base64Image,
revised_prompt: Option<&'a str>,
}

#[derive(serde::Deserialize)]
#[serde(transparent)]
struct Base64Image(#[serde(with = "base64")] Vec<u8>);

impl Base64Image {
fn is_png(&self) -> bool {
self.0
.starts_with(&[0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a])
}
}

mod base64 {
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: serde::Deserializer<'de>,
{
use base64::Engine as _;
use serde::Deserialize as _;

base64::engine::general_purpose::STANDARD
.decode(<&str>::deserialize(deserializer)?)
.map_err(serde::de::Error::custom)
}
}
31 changes: 31 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,34 @@ mod openai;
pub use gemini::Gemini;
pub type OpenAiGPT4Turbo = openai::OpenAi<openai::GPT4Turbo>;
pub type OpenAiGPT35Turbo = openai::OpenAi<openai::GPT35Turbo>;

pub trait Image {
fn create(
&self,
prompt: impl AsRef<str> + Send + Sync,
) -> impl Future<Output = anyhow::Result<GeneratedImage>> + Send + Sync;
}

pub struct GeneratedImage {
pub image: Vec<u8>,
pub prompt: String,
pub ext: ImageExt,
}

#[non_exhaustive]
pub enum ImageExt {
Png,
}

impl ImageExt {
pub fn as_str(&self) -> &str {
match self {
Self::Png => "PNG",
}
}
}

mod dalle;

pub type OpenAiDallE2 = dalle::OpenAi<dalle::DallE2>;
pub type OpenAiDallE3 = dalle::OpenAi<dalle::DallE3>;

0 comments on commit 5775763

Please sign in to comment.