diff --git a/Cargo.lock b/Cargo.lock index 8548382..63fdd98 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1420,6 +1420,7 @@ dependencies = [ "thiserror", "tokio", "tokio-util", + "urlencoding", "warp", "zstd", ] @@ -2173,6 +2174,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf-8" version = "0.7.6" diff --git a/Cargo.toml b/Cargo.toml index 89782bb..82e387d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ warp = "*" hyper = "0" http = "0" reqwest = "*" -# wait for a release containing krustlet/oci-distribution#57 and #71 +# wait for a release containing krustlet/oci-distribution#57, #71, and #90 oci-distribution = { git = "https://github.com/krustlet/oci-distribution" } clap = { version = "*", features = [ "cargo", "derive" ] } clap_complete = "*" @@ -37,3 +37,4 @@ nix-base32 = "*" sha2 = "*" tempfile = "*" ed25519-compact = "*" +urlencoding = "*" diff --git a/src/convert.rs b/src/convert.rs index 57f2080..910a0e5 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -1,15 +1,260 @@ +use clap::Parser; +use clap::ValueEnum; use data_encoding::BASE32_DNSSEC; +use once_cell::sync::Lazy; +use std::collections::BTreeMap; use crate::error::Error; -pub fn key_to_tag(key: &str) -> String { +#[derive(Clone, Debug, Parser)] +pub struct EncodingOptions { + #[arg(long, value_enum, default_value = "custom")] + pub tag_encoding: TagEncoding, + #[arg(long)] + pub fallback_encodings: Vec, +} + +impl EncodingOptions { + pub fn key_to_tag(&self, key: &str) -> (String, Vec) { + let main = self.tag_encoding.key_to_tag(key); + let fallbacks = self + .fallback_encodings + .iter() + .map(|e| e.key_to_tag(key)) + .collect(); + (main, fallbacks) + } + pub fn tag_to_key(&self, tag: &str) -> Result { + let mut errors = vec![]; + let main = [self.tag_encoding]; + let 
encodings = main.iter().chain(self.fallback_encodings.iter()); + for e in encodings { + match e.tag_to_key(tag) { + Ok(r) => return Ok(r), + Err(e) => errors.push(e), + } + } + Err(Error::TagToKey(errors)) + } +} + +#[derive(Clone, Copy, Debug, ValueEnum)] +pub enum TagEncoding { + // A custom encoding + Custom, // https://docs.rs/data-encoding/latest/data_encoding/constant.BASE32_DNSSEC.html // It uses a base32 extended hex alphabet. // It is case-insensitive when decoding and uses lowercase when encoding. // It does not use padding. - BASE32_DNSSEC.encode(key.as_bytes()) + Base32DNSSEC, } -pub fn tag_to_key(tag: &str) -> Result { - Ok(String::from_utf8(BASE32_DNSSEC.decode(tag.as_bytes())?)?) +static CUSTOM_ENCODING: Lazy = Lazy::new(CustomEncoding::new); + +impl TagEncoding { + pub fn key_to_tag(&self, key: &str) -> String { + match self { + TagEncoding::Custom => CUSTOM_ENCODING.encode(key), + TagEncoding::Base32DNSSEC => BASE32_DNSSEC.encode(key.as_bytes()), + } + } + + pub fn tag_to_key(&self, tag: &str) -> Result { + match self { + TagEncoding::Custom => CUSTOM_ENCODING.decode(tag), + TagEncoding::Base32DNSSEC => { + Ok(String::from_utf8(BASE32_DNSSEC.decode(tag.as_bytes())?)?) 
+ } + } + } +} + +/// A tag MUST be at most 128 characters in length and MUST match the following regular expression: +/// [a-zA-Z0-9_][a-zA-Z0-9._-]{0,127} +/// https://github.com/opencontainers/distribution-spec/blob/main/spec.md +#[derive(Clone, Debug)] +pub struct CustomEncoding { + symbol_table: Vec, + reverse_table: BTreeMap, +} + +impl Default for CustomEncoding { + fn default() -> Self { + Self::new() + } +} + +impl CustomEncoding { + pub fn new() -> CustomEncoding { + let mut symbol_table = Vec::new(); + symbol_table.extend('0'..='9'); + symbol_table.extend('a'..='z'); + symbol_table.extend('A'..='Z'); + symbol_table.push('-'); + symbol_table.push('.'); + + let mut reverse_table = BTreeMap::new(); + + for (i, c) in symbol_table.iter().enumerate() { + reverse_table.insert(*c, i as u32); + } + + CustomEncoding { + symbol_table, + reverse_table, + } + } + + pub fn encode(&self, key: &str) -> String { + let mut result = String::new(); + let mut first = true; + for c in key.chars() { + match c { + 'a'..='z' | 'A'..='Z' | '0'..='9' => result.push(c), + '-' | '.' => { + if first { + self.encode_char(&mut result, c); + } else { + result.push(c); + } + } + _ => self.encode_char(&mut result, c), + } + first = false; + } + result + } + + fn encode_char(&self, result: &mut String, c: char) { + result.push('_'); + + let mut n: u32 = c.into(); + + let mut char_code = Vec::new(); + let base = self.symbol_table.len() as u32; + while n != 0 { + let quotient = n / base; + let remainder = n % base; + + char_code.push(self.symbol_table[remainder as usize]); + + n = quotient; + } + result.extend(char_code.iter().rev()); + + result.push('_'); + } + + pub fn decode(&self, tag: &str) -> Result { + let mut chars = tag.chars(); + let mut result = String::new(); + while let Some(c) = chars.next() { + match c { + 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '.' 
=> result.push(c), + '_' => { + let mut encoded_char = Vec::new(); + loop { + match chars.next() { + Some('_') => break, + Some(n) => encoded_char.push(n), + None => return Err(Error::InvalidTag(tag.to_string())), + } + } + result.push( + self.decode_char(&encoded_char) + .ok_or_else(|| Error::InvalidTag(tag.to_string()))?, + ) + } + _ => return Err(Error::InvalidTag(tag.to_string())), + } + } + Ok(result) + } + + fn decode_char(&self, encoded: &[char]) -> Option { + let base = self.symbol_table.len() as u32; + let mut n = 0u32; + for c in encoded.iter() { + n = n.checked_mul(base)?; + n = n.checked_add(*self.reverse_table.get(c)?)?; + } + n.try_into().ok() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn custom_encode_symbol_table_length() { + assert_eq!(CUSTOM_ENCODING.symbol_table.len(), 64); + } + + #[test] + fn custom_encode_symbol_table_validate() { + assert_eq!( + CUSTOM_ENCODING.symbol_table.len(), + CUSTOM_ENCODING.reverse_table.len() + ); + for (i, c) in CUSTOM_ENCODING.symbol_table.iter().enumerate() { + assert_eq!(CUSTOM_ENCODING.reverse_table[c], i as u32); + } + for (c, i) in CUSTOM_ENCODING.reverse_table.iter() { + assert_eq!(CUSTOM_ENCODING.symbol_table[*i as usize], *c); + } + } + + #[test] + fn custom_encode_id() { + assert_eq!( + CUSTOM_ENCODING + .encode("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-."), + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-." 
+ ); + } + + #[test] + fn custom_encode_first_special() { + assert_eq!(CUSTOM_ENCODING.encode("--"), "_J_-"); + assert_eq!(CUSTOM_ENCODING.encode(".."), "_K_."); + assert_eq!(CUSTOM_ENCODING.encode("//"), "_L__L_"); + assert_eq!(CUSTOM_ENCODING.encode("__"), "_1v__1v_"); + } + + #[test] + fn custom_decode_id() { + assert_eq!( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-.", + CUSTOM_ENCODING + .decode("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-.") + .unwrap() + ); + } + + #[test] + fn custom_decode_first_special() { + assert_eq!(("--"), CUSTOM_ENCODING.decode("_J_-").unwrap()); + assert_eq!((".."), CUSTOM_ENCODING.decode("_K_.").unwrap()); + assert_eq!(("//"), CUSTOM_ENCODING.decode("_L__L_").unwrap()); + assert_eq!(("__"), CUSTOM_ENCODING.decode("_1v__1v_").unwrap()); + } + + #[test] + fn custom_encode_decode() { + let test_strings = [ + "test", + "测试", + "_test-测试_", + "._test-测试_.", + "._test-测试_.测试", + "realisations/sha256:67890e0958e5d1a2944a3389151472a9acde025c7812f68381a7eef0d82152d1!libgcc.doi" + ]; + for s in test_strings { + assert_eq!( + CUSTOM_ENCODING.decode(&CUSTOM_ENCODING.encode(s)).unwrap(), + s + ); + } + } } diff --git a/src/error.rs b/src/error.rs index 9f62a44..b68456c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,13 +12,17 @@ pub enum Error { Http(#[from] http::Error), #[error("decode error: {0}")] Decode(#[from] data_encoding::DecodeError), + #[error("tag-to-key error: {0:?}")] + TagToKey(Vec), + #[error("invalid tag error: {0}")] + InvalidTag(String), #[error("from utf-8 error: {0}")] FromUtf8(#[from] FromUtf8Error), #[error("invalid authorization header: {0}")] InvalidAuthorization(String), #[error("oci distribution error: {0}")] OciDistribution(#[from] OciDistributionError), - #[error("invalid imag`e layer count: {0}")] + #[error("invalid image layer count: {0}")] InvalidLayerCount(usize), #[error("invalid image layer media type: {0}")] InvalidLayerMediaType(String), @@ -85,6 +89,8 @@ impl 
Error { match self { Error::Http(_) => StatusCode::INTERNAL_SERVER_ERROR, Error::Decode(_) => StatusCode::BAD_REQUEST, + Error::TagToKey(_) => StatusCode::BAD_REQUEST, + Error::InvalidTag(_) => StatusCode::BAD_REQUEST, Error::FromUtf8(_) => StatusCode::BAD_REQUEST, Error::InvalidAuthorization(_) => StatusCode::BAD_REQUEST, Error::OciDistribution(_) => StatusCode::BAD_REQUEST, diff --git a/src/key.rs b/src/key.rs index c92cd37..0b47cb9 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,13 +1,24 @@ -use crate::{ - convert::{key_to_tag, tag_to_key}, - error::Error, - options::TagCommands, -}; +use crate::{error::Error, options::TagCommands}; pub async fn key_main(command: TagCommands) -> Result<(), Error> { match command { - TagCommands::Encode { key } => print!("{}", key_to_tag(&key)), - TagCommands::Decode { tag } => print!("{}", tag_to_key(&tag)?), + TagCommands::Encode { + key, + fallbacks, + encoding_options, + } => { + let (m, f) = encoding_options.key_to_tag(&key); + println!("{}", m); + if fallbacks { + for tag in f { + println!("{}", tag); + } + } + } + TagCommands::Decode { + tag, + encoding_options, + } => println!("{}", encoding_options.tag_to_key(&tag)?), } Ok(()) } diff --git a/src/options.rs b/src/options.rs index ce287a6..e21b2a5 100644 --- a/src/options.rs +++ b/src/options.rs @@ -1,8 +1,11 @@ use clap::{Parser, Subcommand}; use regex::Regex; +use reqwest::Url; use std::net::SocketAddr; +use crate::convert::EncodingOptions; + #[derive(Clone, Debug, Parser)] #[command(author, version, about, long_about = None)] pub struct Options { @@ -28,15 +31,39 @@ pub struct ServerOptions { pub max_retry: usize, #[arg(long, help = "disable ssl")] pub no_ssl: bool, + #[arg(short, long, value_name = "URL", help = "upstream cache URLs")] + pub upstream: Vec, + #[arg( + short, + long, + value_name = "PATTERN", + default_value = "nix-cache-info", + help = "ignored file matched when querying upstream" + )] + pub ignore_upstream: Regex, + #[arg(long, help = "upstream 
anonymous queries")] + pub upstream_anonymous: bool, + #[clap(flatten)] + pub encoding_options: EncodingOptions, } #[derive(Clone, Debug, Subcommand)] #[command(about = "Command line tools for tag-key conversion")] pub enum TagCommands { #[command(about = "Encode a key to tag")] - Encode { key: String }, + Encode { + key: String, + #[arg(long)] + fallbacks: bool, + #[clap(flatten)] + encoding_options: EncodingOptions, + }, #[command(about = "Decode a tag to key")] - Decode { tag: String }, + Decode { + tag: String, + #[clap(flatten)] + encoding_options: EncodingOptions, + }, } #[derive(Clone, Debug, Parser)] @@ -71,6 +98,8 @@ pub struct PushOptions { pub allow_immutable_db: bool, #[arg(long, help = "disable ssl")] pub no_ssl: bool, + #[clap(flatten)] + pub encoding_options: EncodingOptions, #[command(subcommand)] pub subcommand: Option, } diff --git a/src/registry.rs b/src/registry.rs index 580fa9f..785b17c 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -1,5 +1,10 @@ use std::fmt; +use crate::convert::EncodingOptions; +use crate::{ + error::Error, + options::{PushOptions, ServerOptions}, +}; use maplit::hashmap; use oci_distribution::{ client::{ClientConfig, ClientProtocol, Config, ImageLayer}, @@ -10,12 +15,6 @@ use oci_distribution::{ Client, Reference, }; -use crate::{ - convert::key_to_tag, - error::Error, - options::{PushOptions, ServerOptions}, -}; - pub const LAYER_MEDIA_TYPE: &str = "application/octet-stream"; pub const CONTENT_TYPE_ANNOTATION: &str = "com.linyinfeng.oranc.content.type"; @@ -30,6 +29,7 @@ pub struct RegistryOptions { pub no_ssl: bool, pub dry_run: bool, pub max_retry: usize, + pub encoding_options: EncodingOptions, } #[derive(Debug, Clone)] @@ -47,17 +47,26 @@ pub struct OciItem { #[derive(Debug, Clone)] pub struct LayerInfo { + pub reference: Reference, pub digest: String, pub content_type: String, } impl OciLocation { - pub fn reference(&self) -> Reference { - Reference::with_tag( - self.registry.clone(), - self.repository.clone(), 
- key_to_tag(&self.key), - ) + pub fn reference(&self, encoding_options: &EncodingOptions) -> (Reference, Vec) { + let build_ref = + |tag| Reference::with_tag(self.registry.clone(), self.repository.clone(), tag); + let (main_tag, fallback_tags) = encoding_options.key_to_tag(&self.key); + let main = build_ref(main_tag); + let fallback_refs = fallback_tags.into_iter().map(build_ref).collect(); + (main, fallback_refs) + } + + pub fn references_merged(&self, encoding_options: &EncodingOptions) -> Vec { + let (main, fallbacks) = self.reference(encoding_options); + let mut result = vec![main]; + result.extend(fallbacks); + result } } @@ -70,35 +79,54 @@ pub async fn get_layer_info( return Err(Error::InvalidMaxRetry(max_retry)); } - let reference = location.reference(); + let references = location.references_merged(&ctx.options.encoding_options); let mut pull_result = None; let mut errors = vec![]; - for attempt in 1..max_retry { - log::debug!("pull image manifest {reference:?}, attempt {attempt}/{max_retry}"); - match ctx.client.pull_image_manifest(&reference, &ctx.auth).await { - Ok(r) => pull_result = Some(r), - Err(OciDistributionError::ImageManifestNotFoundError(_)) => return Ok(None), - Err(OciDistributionError::RegistryError { envelope, .. 
}) - if envelope - .errors - .iter() - .all(|e| e.code == OciErrorCode::ManifestUnknown) => - { - return Ok(None) - } - Err(oci_error) => { - let e = oci_error.into(); - log::warn!( - "pull image manifest {reference:?}, attempt {attempt}/{max_retry} failed: {}", - e - ); - errors.push(e); + 'fallbacks: for reference in references { + let mut ref_errors = vec![]; + 'retries: for attempt in 1..max_retry { + log::debug!("pull image manifest {reference:?}, attempt {attempt}/{max_retry}"); + match ctx.client.pull_image_manifest(&reference, &ctx.auth).await { + Ok(res) => { + pull_result = Some((reference.clone(), res)); + break 'fallbacks; + } + Err(OciDistributionError::ImageManifestNotFoundError(_)) => break 'retries, + Err(OciDistributionError::RegistryError { envelope, .. }) + if envelope + .errors + .iter() + .all(|e| e.code == OciErrorCode::ManifestUnknown) => + { + break 'retries + } + Err(oci_error) => { + let e = oci_error.into(); + log::warn!( + "pull image manifest {reference:?}, attempt {attempt}/{max_retry} failed: {}", + e + ); + ref_errors.push(e); + } } } + if ref_errors.len() == max_retry { + log::error!("pull image manifest {reference:?} failed"); + // all retries failed + errors.extend(ref_errors); + } } - let (manifest, _hash) = match pull_result { + let (reference, (manifest, _hash)) = match pull_result { Some(r) => r, - None => return Err(Error::RetryAllFails(errors)), + None => { + if errors.is_empty() { + // all references not found + return Ok(None); + } else { + // at least one reference failed + return Err(Error::RetryAllFails(errors)); + } + } }; match manifest.layers.len() { @@ -124,6 +152,7 @@ pub async fn get_layer_info( } }; let info = LayerInfo { + reference, digest: layer_manifest.digest.clone(), content_type: content_type.clone(), }; @@ -181,7 +210,7 @@ pub async fn put( if max_retry < 1 { return Err(Error::InvalidMaxRetry(max_retry)); } - let reference = location.reference(); + let (reference, _fallbacks) = 
location.reference(&ctx.options.encoding_options); let mut errors = vec![]; for attempt in 1..max_retry { log::debug!("push {reference:?}, attempt {attempt}/{max_retry}"); @@ -220,6 +249,7 @@ impl RegistryOptions { dry_run: options.dry_run, max_retry: options.max_retry, no_ssl: options.no_ssl, + encoding_options: options.encoding_options.clone(), } } @@ -228,6 +258,7 @@ impl RegistryOptions { dry_run: false, max_retry: options.max_retry, no_ssl: options.no_ssl, + encoding_options: options.encoding_options.clone(), } } diff --git a/src/server.rs b/src/server.rs index ed48d1b..1ed0a9d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,10 +1,14 @@ use crate::error::Error; +use crate::registry; use crate::registry::get_layer_info; use crate::registry::LayerInfo; +use crate::registry::OciItem; use crate::registry::OciLocation; - use crate::registry::RegistryOptions; +pub mod upstream; + +use bytes::Bytes; use data_encoding::BASE64; use http::header; use http::Response; @@ -18,6 +22,7 @@ use warp::{Filter, Rejection, Reply}; use crate::options::ServerOptions; +const OK_RESPONSE_BODY: &str = "<_/>"; const NO_SUCH_KEY_RESPONSE_BODY: &str = "NoSuchKey"; static AWS_AUTH_PATTERN: Lazy = @@ -26,18 +31,23 @@ static BASIC_AUTH_PATTERN: Lazy = Lazy::new(|| Regex::new("^Basic (.*)$") static DECODED_PATTERN: Lazy = Lazy::new(|| Regex::new("^([^:]+):(.+)$").unwrap()); #[derive(Debug, Clone)] -struct ServerContext { +pub struct ServerContext { options: ServerOptions, + http_client: reqwest::Client, } -async fn get( +pub async fn get( ctx: ServerContext, location: OciLocation, auth: RegistryAuth, ) -> Result, Rejection> { log::info!("get: {location}"); + if let Some(response) = upstream::check_and_redirect(&ctx, &location.key, &auth).await? 
{ + return Ok(response); + } let mut registry_ctx = RegistryOptions::from_server_options(&ctx.options).context(auth); let LayerInfo { + reference, digest, content_type, } = get_layer_info(&mut registry_ctx, &location) @@ -45,7 +55,7 @@ async fn get( .ok_or(Error::ReferenceNotFound(location.clone()))?; let blob_stream = registry_ctx .client - .pull_blob_stream(&location.reference(), &digest) + .pull_blob_stream(&reference, &digest) .await .map_err(Error::OciDistribution)?; Ok(Response::builder() @@ -55,14 +65,18 @@ async fn get( .map_err(Error::Http)?) } -async fn head( +pub async fn head( ctx: ServerContext, location: OciLocation, auth: RegistryAuth, ) -> Result, Rejection> { log::info!("head: {location}"); + if let Some(response) = upstream::check_and_redirect(&ctx, &location.key, &auth).await? { + return Ok(response); + } let mut registry_ctx = RegistryOptions::from_server_options(&ctx.options).context(auth); let LayerInfo { + reference: _, digest: _, content_type, } = get_layer_info(&mut registry_ctx, &location) @@ -75,11 +89,32 @@ async fn head( .map_err(Error::Http)?) } -fn registry_auth() -> impl Filter + Copy { +pub async fn put( + ctx: ServerContext, + location: OciLocation, + auth: RegistryAuth, + content_type: Option, + body: Bytes, +) -> Result, Rejection> { + log::info!("put: {location}"); + // no upstream query for put + let mut registry_ctx = RegistryOptions::from_server_options(&ctx.options).context(auth); + let item = OciItem { + content_type, + data: body.to_vec(), + }; + registry::put(&mut registry_ctx, &location, item).await?; + Ok(Response::builder() + .status(StatusCode::OK) + .body(OK_RESPONSE_BODY) + .map_err(Error::Http)?) 
+} + +pub fn registry_auth() -> impl Filter + Copy { warp::header::optional("authorization").and_then(parse_auth) } -async fn parse_auth(opt: Option) -> Result { +pub async fn parse_auth(opt: Option) -> Result { match opt { None => Ok(RegistryAuth::Anonymous), Some(original) => { @@ -102,30 +137,39 @@ async fn parse_auth(opt: Option) -> Result { } } -fn oci_location() -> impl Filter + Copy { +pub fn oci_location() -> impl Filter + Copy { warp::path::param() // registry .and(warp::path::param()) // repository part1 .and(warp::path::param()) // repository part1 .and(warp::path::tail()) // key - .map( - |registry, rep1: String, rep2: String, tail: warp::path::Tail| { - let repository = format!("{rep1}/{rep2}"); - let key = tail.as_str(); - OciLocation { - registry, - repository, - key: key.to_owned(), - } - }, - ) + .and_then(convert_to_oci_location) } -async fn handle_error(rejection: Rejection) -> Result { +pub async fn convert_to_oci_location( + registry: String, + rep1: String, + rep2: String, + tail: warp::path::Tail, +) -> Result { + let tail_str = tail.as_str(); + let decoded_registry = urlencoding::decode(&registry).map_err(Error::FromUtf8)?; + let decoded_rep1 = urlencoding::decode(&rep1).map_err(Error::FromUtf8)?; + let decoded_rep2 = urlencoding::decode(&rep2).map_err(Error::FromUtf8)?; + let decoded_tail = urlencoding::decode(tail_str).map_err(Error::FromUtf8)?; + let repository = format!("{decoded_rep1}/{decoded_rep2}"); + Ok(OciLocation { + registry: decoded_registry.to_string(), + repository, + key: decoded_tail.to_string(), + }) +} + +pub async fn handle_error(rejection: Rejection) -> Result { log::trace!("handle rejection: {rejection:?}"); let code; let message; if let Some(e) = rejection.find::() { - log::info!("handle error: {e}"); + log::debug!("handle error: {e}"); code = e.code(); match e { // otherwise aws clients can not decode 404 error message @@ -141,13 +185,17 @@ async fn handle_error(rejection: Rejection) -> Result { 
Ok(warp::reply::with_status(message, code)) } -async fn log_rejection(rejection: Rejection) -> Result, Rejection> { +pub async fn log_rejection(rejection: Rejection) -> Result, Rejection> { log::debug!("unhandled rejection: {rejection:?}"); Err(rejection) } pub async fn server_main(options: ServerOptions) -> Result<(), Error> { - let ctx = ServerContext { options }; + let http_client = reqwest::Client::new(); + let ctx = ServerContext { + options, + http_client, + }; let ctx_filter = { let ctx = ctx.clone(); @@ -164,6 +212,12 @@ pub async fn server_main(options: ServerOptions) -> Result<(), Error> { .or(warp::head() .and(common()) .and_then(head) + .recover(handle_error)) + .or(warp::put() + .and(common()) + .and(warp::header::optional("content-type")) + .and(warp::body::bytes()) + .and_then(put) .recover(handle_error)); let log = warp::log::custom(|info| { diff --git a/src/server/upstream.rs b/src/server/upstream.rs new file mode 100644 index 0000000..f6c3311 --- /dev/null +++ b/src/server/upstream.rs @@ -0,0 +1,88 @@ +use std::path::PathBuf; + +use http::{Response, StatusCode}; +use hyper::Body; +use oci_distribution::secrets::RegistryAuth; +use reqwest::Url; +use warp::Rejection; + +use crate::error::Error; + +use super::ServerContext; + +pub async fn check_and_redirect( + ctx: &ServerContext, + key: &str, + auth: &RegistryAuth, +) -> Result>, Rejection> { + match check(ctx, key, auth).await? 
{ + Some(url) => Ok(Some(redirect_response(key, &url)?)), + None => Ok(None), + } +} + +pub async fn check( + ctx: &ServerContext, + key: &str, + auth: &RegistryAuth, +) -> Result, Error> { + let max_retry = ctx.options.max_retry; + if max_retry < 1 { + return Err(Error::InvalidMaxRetry(max_retry)); + } + + if let RegistryAuth::Anonymous = auth { + // skip check upstream caches if `--upstream-anonymous` is off + if !ctx.options.upstream_anonymous { + log::debug!("skipped checking upstream for key: '{}'", key); + return Ok(None); + } + } + if ctx.options.ignore_upstream.is_match(key) { + return Ok(None); + } + for upstream in &ctx.options.upstream { + let url = upstream_url(upstream, key)?; + for attempt in 1..max_retry { + let response = ctx + .http_client + .head(url.clone()) + .send() + .await + .map_err(Error::Reqwest)?; + if response.status() == StatusCode::OK { + return Ok(Some(url)); + } else if response.status() == StatusCode::NOT_FOUND { + break; + } else { + log::warn!( + "query upstream url '{url}', attempt {attempt}/{max_retry} failed: {:?}", + response + ); + } + } + } + Ok(None) +} + +pub fn upstream_url(base: &Url, key: &str) -> Result { + let path = base.path(); + let new_path = PathBuf::from(path).join(key); + match new_path.to_str() { + Some(p) => { + let mut upstream = base.clone(); + upstream.set_path(p); + Ok(upstream) + } + None => Err(Error::InvalidPath(new_path)), + } +} + +pub fn redirect_response(key: &str, url: &Url) -> Result, Error> { + log::info!("redirect: key = {key}, url = {url}"); + Response::builder() + .status(StatusCode::FOUND) + .header(http::header::LOCATION, url.to_string()) + .body(Body::empty()) + .map_err(Error::Http) +}