From 49273986cc7a280480bb659e0f24ec368f936c00 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Thu, 2 Jan 2025 16:39:21 -0800 Subject: [PATCH 1/9] Add new client method for getting metrics in json and prometheus format --- clients/python/lorax/client.py | 224 ++++++++++++++++++++++++--------- clients/python/lorax/types.py | 46 +++++-- router/src/server.rs | 91 ++++++++++++-- 3 files changed, 289 insertions(+), 72 deletions(-) diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index aa58e800d..94327fcb2 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -2,10 +2,11 @@ import logging import requests from requests.adapters import HTTPAdapter, Retry - +import copy +import os from aiohttp import ClientSession, ClientTimeout from pydantic import ValidationError -from typing import Any, Dict, Optional, List, AsyncIterator, Iterator, Union +from typing import Any, Dict, Optional, List, AsyncIterator, Iterator, Union, Literal from lorax.types import ( BatchRequest, @@ -16,10 +17,10 @@ MergedAdapters, ResponseFormat, EmbedResponse, - ClassifyResponse + ClassifyResponse, + MetricsResponse, ) from lorax.errors import parse_error -import os LORAX_DEBUG_MODE = os.getenv("LORAX_DEBUG_MODE", None) is not None if LORAX_DEBUG_MODE: @@ -28,6 +29,7 @@ # You will see the REQUEST, including HEADERS and DATA, and RESPONSE with HEADERS but without DATA. # The only thing missing will be the response.body which is not logged. import http.client as http_client + http_client.HTTPConnection.debuglevel = 1 # You must initialize logging, otherwise you'll not see debug output. @@ -87,6 +89,7 @@ def __init__( self.embed_endpoint = f"{base_url}/embed" self.classify_endpoint = f"{base_url}/classify" self.classify_batch_endpoint = f"{base_url}/classify_batch" + self.metrics_endpoint = f"{base_url}/metrics" self.headers = headers self.cookies = cookies self.timeout = timeout @@ -109,7 +112,7 @@ def _create_session(self): ) self.session.mount("https://", adapter) self.session.mount("http://", adapter) - + def _post(self, json: dict, stream: bool = False) -> requests.Response: """ Given inputs, make an HTTP POST call @@ -120,8 +123,8 @@ def _post(self, json: dict, stream: bool = False) -> requests.Response: stream (`bool`): Whether to stream the HTTP response or not - - Returns: + + Returns: requests.Response: HTTP response object """ # Instantiate session if currently None @@ -130,7 +133,7 @@ def _post(self, json: dict, stream: bool = False) -> requests.Response: # Retry if the session is stale and hits a ConnectionError current_retry_attempt = 0 - + # Make the HTTP POST request while True: try: @@ -140,18 +143,21 @@ def _post(self, json: dict, stream: bool = False) -> requests.Response: headers=self.headers, cookies=self.cookies, timeout=self.timeout, - stream=stream + stream=stream, ) return resp - except (requests.exceptions.ConnectionError, requests.exceptions.ConnectTimeout) as e: + except ( + requests.exceptions.ConnectionError, + requests.exceptions.ConnectTimeout, + ) as e: # Refresh session if there is a ConnectionError self.session = None self._create_session() - + # Raise error if retries have been exhausted if current_retry_attempt >= self.max_session_retries: raise e - + current_retry_attempt += 1 except Exception as e: # Raise any other exception @@ -219,13 +225,14 @@ def generate( top_k (`int`): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`): - If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or - higher are kept for generation. + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to + `top_p` or higher are kept for generation. truncate (`int`): Truncate inputs tokens to the given size typical_p (`float`): Typical Decoding mass - See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for + more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) response_format (`Optional[Union[Dict[str, Any], ResponseFormat]]`): @@ -242,7 +249,8 @@ def generate( decoder_input_details (`bool`): Return the decoder input token logprobs and ids return_k_alternatives (`int`): - The number of highest probability vocabulary tokens to return as alternative tokens in the generation result + The number of highest probability vocabulary tokens to return as alternative tokens in the + generation result details (`bool`): Return the token logprobs and ids for generated tokens @@ -272,7 +280,7 @@ def generate( watermark=watermark, response_format=response_format, decoder_input_details=decoder_input_details, - return_k_alternatives=return_k_alternatives + return_k_alternatives=return_k_alternatives, ) # Instantiate the request object @@ -292,8 +300,10 @@ def generate( payload = {"message": e.msg} if resp.status_code != 200: - raise parse_error(resp.status_code, payload, resp.headers if LORAX_DEBUG_MODE else None) - + raise parse_error( + resp.status_code, payload, resp.headers if LORAX_DEBUG_MODE else None + ) + if LORAX_DEBUG_MODE: print(resp.headers) @@ -356,13 +366,14 @@ def generate_stream( top_k (`int`): The number of highest probability vocabulary tokens to keep for top-k-filtering. top_p (`float`): - If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or - higher are kept for generation. + If set to < 1, only the smallest set of most probable tokens with probabilities that add up + to `top_p` or higher are kept for generation. truncate (`int`): Truncate inputs tokens to the given size typical_p (`float`): Typical Decoding mass - See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for + more information watermark (`bool`): Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) response_format (`Optional[Union[Dict[str, Any], ResponseFormat]]`): @@ -415,7 +426,11 @@ def generate_stream( ) if resp.status_code != 200: - raise parse_error(resp.status_code, resp.json(), resp.headers if LORAX_DEBUG_MODE else None) + raise parse_error( + resp.status_code, + resp.json(), + resp.headers if LORAX_DEBUG_MODE else None, + ) # Parse ServerSentEvents for byte_payload in resp.iter_lines(): @@ -434,10 +449,13 @@ def generate_stream( response = StreamResponse(**json_payload) except ValidationError: # If we failed to parse the payload, then it is an error payload - raise parse_error(resp.status_code, json_payload, resp.headers if LORAX_DEBUG_MODE else None) + raise parse_error( + resp.status_code, + json_payload, + resp.headers if LORAX_DEBUG_MODE else None, + ) yield response - def embed(self, inputs: str) -> EmbedResponse: """ Given inputs, embed the text using the model @@ -445,8 +463,8 @@ def embed(self, inputs: str) -> EmbedResponse: Args: inputs (`str`): Input text - - Returns: + + Returns: Embeddings: computed embeddings """ request = Request(inputs=inputs) @@ -461,10 +479,13 @@ def embed(self, inputs: str) -> EmbedResponse: payload = resp.json() if resp.status_code != 200: - raise parse_error(resp.status_code, resp.json(), resp.headers if LORAX_DEBUG_MODE else None) - - return EmbedResponse(**payload) + raise parse_error( + resp.status_code, + resp.json(), + resp.headers if LORAX_DEBUG_MODE else None, + ) + return EmbedResponse(**payload) def classify(self, inputs: str) -> ClassifyResponse: """ @@ -473,8 +494,8 @@ def classify(self, inputs: str) -> ClassifyResponse: Args: inputs (`str`): Input text - - Returns: + + Returns: Entities: Entities found in the input text """ request = Request(inputs=inputs) @@ -489,10 +510,14 @@ def classify(self, inputs: str) -> ClassifyResponse: payload = resp.json() if resp.status_code != 200: - raise parse_error(resp.status_code, resp.json(), resp.headers if LORAX_DEBUG_MODE else None) - + raise parse_error( + resp.status_code, + resp.json(), + resp.headers if LORAX_DEBUG_MODE else None, + ) + return ClassifyResponse(entities=payload) - + def classify_batch(self, inputs: List[str]) -> List[ClassifyResponse]: """ Given a list of inputs, run token classification on the text using the model @@ -500,8 +525,8 @@ def classify_batch(self, inputs: List[str]) -> List[ClassifyResponse]: Args: inputs (`List[str]`): List of input texts - - Returns: + + Returns: List[Entities]: Entities found in the input text """ request = BatchRequest(inputs=inputs) @@ -516,10 +541,55 @@ def classify_batch(self, inputs: List[str]) -> List[ClassifyResponse]: payload = resp.json() if resp.status_code != 200: - raise parse_error(resp.status_code, resp.json(), resp.headers if LORAX_DEBUG_MODE else None) - + raise parse_error( + resp.status_code, + resp.json(), + resp.headers if LORAX_DEBUG_MODE else None, + ) + return [ClassifyResponse(entities=e) for e in payload] + def metrics( + self, format: Optional[Literal["json", "prometheus"]] = "prometheus" + ) -> MetricsResponse: + """ + Get the metrics of the model + + Args: + format (`Optional[Literal["json", "prometheus"]]`): + Format of the metrics + + Returns: + MetricsResponse: metrics in the specified format + """ + headers = copy.deepcopy(self.headers) + if format == "json": + if self.headers is None: + headers = {"Accept": "application/json"} + else: + headers = {**self.headers, "Accept": "application/json"} + + resp = requests.get( + self.metrics_endpoint, + headers=headers, + cookies=self.cookies, + timeout=self.timeout, + ) + + if format == "json": + payload = resp.json() + else: + payload = resp.text + + if resp.status_code != 200: + raise parse_error( + resp.status_code, + resp.json(), + resp.headers if LORAX_DEBUG_MODE else None, + ) + + return MetricsResponse(metrics=payload) + class AsyncClient: """Asynchronous Client to make calls to a LoRAX instance @@ -690,11 +760,15 @@ async def generate( async with ClientSession( headers=self.headers, cookies=self.cookies, timeout=self.timeout ) as session: - async with session.post(self.base_url, json=request.dict(by_alias=True)) as resp: + async with session.post( + self.base_url, json=request.dict(by_alias=True) + ) as resp: payload = await resp.json() if resp.status != 200: - raise parse_error(resp.status, payload, resp.headers if LORAX_DEBUG_MODE else None) + raise parse_error( + resp.status, payload, resp.headers if LORAX_DEBUG_MODE else None + ) return Response(**payload[0]) async def generate_stream( @@ -720,7 +794,6 @@ async def generate_stream( response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None, details: bool = True, return_k_alternatives: Optional[int] = None, - ) -> AsyncIterator[StreamResponse]: """ Given a prompt, generate the following stream of tokens asynchronously @@ -814,10 +887,15 @@ async def generate_stream( async with ClientSession( headers=self.headers, cookies=self.cookies, timeout=self.timeout ) as session: - async with session.post(self.base_url, json=request.dict(by_alias=True)) as resp: - + async with session.post( + self.base_url, json=request.dict(by_alias=True) + ) as resp: if resp.status != 200: - raise parse_error(resp.status, await resp.json(), resp.headers if LORAX_DEBUG_MODE else None) + raise parse_error( + resp.status, + await resp.json(), + resp.headers if LORAX_DEBUG_MODE else None, + ) # Parse ServerSentEvents async for byte_payload in resp.content: @@ -836,9 +914,12 @@ async def generate_stream( response = StreamResponse(**json_payload) except ValidationError: # If we failed to parse the payload, then it is an error payload - raise parse_error(resp.status, json_payload, resp.headers if LORAX_DEBUG_MODE else None) + raise parse_error( + resp.status, + json_payload, + resp.headers if LORAX_DEBUG_MODE else None, + ) yield response - async def embed(self, inputs: str) -> EmbedResponse: """ @@ -847,22 +928,25 @@ async def embed(self, inputs: str) -> EmbedResponse: Args: inputs (`str`): Input text - - Returns: + + Returns: Embeddings: computed embeddings """ request = Request(inputs=inputs) async with ClientSession( headers=self.headers, cookies=self.cookies, timeout=self.timeout ) as session: - async with session.post(self.embed_endpoint, json=request.dict(by_alias=True)) as resp: + async with session.post( + self.embed_endpoint, json=request.dict(by_alias=True) + ) as resp: payload = await resp.json() if resp.status != 200: - raise parse_error(resp.status, payload, resp.headers if LORAX_DEBUG_MODE else None) + raise parse_error( + resp.status, payload, resp.headers if LORAX_DEBUG_MODE else None + ) return EmbedResponse(**payload) - async def classify(self, inputs: str) -> ClassifyResponse: """ Given inputs, run token classification on the text using the model @@ -870,17 +954,45 @@ async def classify(self, inputs: str) -> ClassifyResponse: Args: inputs (`str`): Input text - - Returns: + + Returns: Entities: Entities found in the input text """ request = Request(inputs=inputs) async with ClientSession( headers=self.headers, cookies=self.cookies, timeout=self.timeout ) as session: - async with session.post(self.classify_endpoint, json=request.dict(by_alias=True)) as resp: + async with session.post( + self.classify_endpoint, json=request.dict(by_alias=True) + ) as resp: payload = await resp.json() if resp.status != 200: - raise parse_error(resp.status, payload, resp.headers if LORAX_DEBUG_MODE else None) + raise parse_error( + resp.status, payload, resp.headers if LORAX_DEBUG_MODE else None + ) return ClassifyResponse(**payload) + + async def metrics( + self, format: Literal["json", "prometheus"] = "prometheus" + ) -> MetricsResponse: + """ + Get the metrics of the server + """ + headers = copy.deepcopy(self.headers) + if format == "json": + if headers is None: + headers = {"Accept": "application/json"} + else: + headers["Accept"] = "application/json" + + async with ClientSession( + headers=headers, cookies=self.cookies, timeout=self.timeout + ) as session: + async with session.get(self.metrics_endpoint, headers=headers) as resp: + if format == "json": + payload = await resp.json() + else: + payload = await resp.text() + + return MetricsResponse(metrics=payload) diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index 2fc98b7b7..6e2e19554 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -52,7 +52,9 @@ def validate_density(cls, v): @field_validator("majority_sign_method") def validate_majority_sign_method(cls, v): if v is not None and v not in MAJORITY_SIGN_METHODS: - raise ValidationError(f"`majority_sign_method` must be one of {MAJORITY_SIGN_METHODS}") + raise ValidationError( + f"`majority_sign_method` must be one of {MAJORITY_SIGN_METHODS}" + ) return v @@ -64,7 +66,9 @@ class ResponseFormat(BaseModel): model_config = ConfigDict(use_enum_values=True) type: ResponseFormatType - schema_spec: Optional[Union[Dict[str, Any], OrderedDict]] = Field(None, alias="schema") + schema_spec: Optional[Union[Dict[str, Any], OrderedDict]] = Field( + None, alias="schema" + ) class Parameters(BaseModel): @@ -121,13 +125,17 @@ def valid_adapter_id(self): adapter_id = self.adapter_id merged_adapters = self.merged_adapters if adapter_id is not None and merged_adapters is not None: - raise ValidationError("you must specify at most one of `adapter_id` or `merged_adapters`") + raise ValidationError( + "you must specify at most one of `adapter_id` or `merged_adapters`" + ) return self @field_validator("adapter_source") def valid_adapter_source(cls, v): if v is not None and v not in ADAPTER_SOURCES: - raise ValidationError(f"`adapter_source={v}` must be one of {ADAPTER_SOURCES}") + raise ValidationError( + f"`adapter_source={v}` must be one of {ADAPTER_SOURCES}" + ) return v @field_validator("best_of") @@ -215,8 +223,15 @@ def valid_input(cls, v): @field_validator("stream") def valid_best_of_stream(cls, field_value, values): parameters = values.data["parameters"] - if parameters is not None and parameters.best_of is not None and parameters.best_of > 1 and field_value: - raise ValidationError("`best_of` != 1 is not supported when `stream` == True") + if ( + parameters is not None + and parameters.best_of is not None + and parameters.best_of > 1 + and field_value + ): + raise ValidationError( + "`best_of` != 1 is not supported when `stream` == True" + ) return field_value @@ -237,8 +252,15 @@ def valid_input(cls, v): @field_validator("stream") def valid_best_of_stream(cls, field_value, values): parameters = values.data["parameters"] - if parameters is not None and parameters.best_of is not None and parameters.best_of > 1 and field_value: - raise ValidationError("`best_of` != 1 is not supported when `stream` == True") + if ( + parameters is not None + and parameters.best_of is not None + and parameters.best_of > 1 + and field_value + ): + raise ValidationError( + "`best_of` != 1 is not supported when `stream` == True" + ) return field_value @@ -370,6 +392,12 @@ class EmbedResponse(BaseModel): # Embeddings embeddings: Optional[List[float]] + class ClassifyResponse(BaseModel): # Classifications - entities: Optional[List[dict]] \ No newline at end of file + entities: Optional[List[dict]] + + +class MetricsResponse(BaseModel): + # Metrics (can be in JSON or Prometheus [string] format) + metrics: Optional[str | dict] diff --git a/router/src/server.rs b/router/src/server.rs index ce2785cb6..c3cf83376 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -20,6 +20,8 @@ use crate::{ TokenizeRequest, TokenizeResponse, Tool, ToolCall, ToolChoice, UsageInfo, Validation, }; use axum::extract::Extension; +use axum::extract::Query; +use axum::http::header; use axum::http::{HeaderMap, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::{IntoResponse, Response}; @@ -35,6 +37,7 @@ use reqwest_middleware::ClientBuilder; use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use serde_json::Value; use std::cmp; +use std::collections::HashMap; use std::convert::Infallible; use std::net::SocketAddr; use std::sync::atomic::AtomicBool; @@ -1180,10 +1183,88 @@ async fn generate_stream_with_callback( get, tag = "LoRAX", path = "/metrics", -responses((status = 200, description = "Prometheus Metrics", body = String)) +params( + ("format", Query, description = "Optional format parameter (prometheus|json)", example = "json") +), +responses( + (status = 200, description = "Prometheus or JSON Metrics", + content( + ("text/plain" = String), + ("application/json" = Object) + )) +) )] -async fn metrics(prom_handle: Extension) -> String { - prom_handle.render() +async fn metrics( + prom_handle: Extension, + format: Option>, + headers: HeaderMap, +) -> Response { + // Check format query param first, then Accept header + let want_json = format + .map(|f| f.0.to_lowercase() == "json") + .unwrap_or_else(|| { + headers + .get(header::ACCEPT) + .and_then(|h| h.to_str().ok()) + .map(|h| h.contains("application/json")) + .unwrap_or(false) + }); + + let prometheus_text = prom_handle.render(); + + if want_json { + // Parse the Prometheus text format into a structured format + let mut counters = HashMap::new(); + let mut gauges = HashMap::new(); + + // Basic parsing of Prometheus format + for line in prometheus_text.lines() { + if line.starts_with('#') || line.is_empty() { + continue; + } + + if let Some((name, value)) = parse_metric_line(line) { + if name.ends_with("_total") { + counters.insert(name, value); + } else { + gauges.insert(name, value); + } + } + } + + let json_response = json!({ + "counters": counters, + "gauges": gauges, + }); + + ( + StatusCode::OK, + [(header::CONTENT_TYPE, "application/json")], + Json(json_response), + ) + .into_response() + } else { + // Return default Prometheus format + ( + StatusCode::OK, + [(header::CONTENT_TYPE, "text/plain")], + prometheus_text, + ) + .into_response() + } +} + +/// Helper function to parse a Prometheus metric line +fn parse_metric_line(line: &str) -> Option<(String, f64)> { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() != 2 { + return None; + } + + let name = parts[0].to_string(); + let value = parts[1].parse().ok()?; + + Some((name, value)) } async fn request_logger( @@ -1794,10 +1875,6 @@ async fn classify( "lorax_request_inference_duration", inference_time.as_secs_f64() ); - metrics::histogram!( - "lorax_request_classify_output_count", - response.predictions.len() as f64 - ); tracing::debug!("Output: {:?}", response.predictions); tracing::info!("Success"); From 1380b3b10e3333c9bbff0ae80901e201989e3da2 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Thu, 2 Jan 2025 16:45:34 -0800 Subject: [PATCH 2/9] clean up client code --- clients/python/lorax/client.py | 46 ++++++++++++++++------------------ 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index 94327fcb2..4aaf1e79a 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -2,7 +2,6 @@ import logging import requests from requests.adapters import HTTPAdapter, Retry -import copy import os from aiohttp import ClientSession, ClientTimeout from pydantic import ValidationError @@ -562,12 +561,11 @@ def metrics( Returns: MetricsResponse: metrics in the specified format """ - headers = copy.deepcopy(self.headers) - if format == "json": - if self.headers is None: - headers = {"Accept": "application/json"} - else: - headers = {**self.headers, "Accept": "application/json"} + # Simplified header assignment + headers = { + **(self.headers or {}), + "Accept": "application/json" if format == "json" else "text/plain", + } resp = requests.get( self.metrics_endpoint, @@ -576,19 +574,17 @@ def metrics( timeout=self.timeout, ) - if format == "json": - payload = resp.json() - else: - payload = resp.text + # Unified payload handling + payload = resp.json() if format == "json" else resp.text if resp.status_code != 200: raise parse_error( resp.status_code, - resp.json(), + payload, resp.headers if LORAX_DEBUG_MODE else None, ) - return MetricsResponse(metrics=payload) + return payload class AsyncClient: @@ -979,20 +975,22 @@ async def metrics( """ Get the metrics of the server """ - headers = copy.deepcopy(self.headers) - if format == "json": - if headers is None: - headers = {"Accept": "application/json"} - else: - headers["Accept"] = "application/json" + # Simplified header assignment + headers = { + **(self.headers or {}), + "Accept": "application/json" if format == "json" else "text/plain", + } async with ClientSession( headers=headers, cookies=self.cookies, timeout=self.timeout ) as session: - async with session.get(self.metrics_endpoint, headers=headers) as resp: - if format == "json": - payload = await resp.json() - else: - payload = await resp.text() + async with session.get(self.metrics_endpoint) as resp: + # Unified payload handling + payload = await resp.json() if format == "json" else await resp.text() + + if resp.status != 200: + raise parse_error( + resp.status, payload, resp.headers if LORAX_DEBUG_MODE else None + ) return MetricsResponse(metrics=payload) From 16d94574a70e7b981f66350bf7b0a185a807ab0e Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Thu, 2 Jan 2025 17:37:31 -0800 Subject: [PATCH 3/9] Clean up server code --- router/src/server.rs | 107 +++++++++++++++++++++++++++---------------- 1 file changed, 68 insertions(+), 39 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index c3cf83376..2417def60 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -54,6 +54,8 @@ use tower_http::cors::{ use tracing::{info_span, instrument, Instrument}; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; +use serde::{Deserialize, Serialize}; + pub static DEFAULT_ADAPTER_SOURCE: OnceCell = OnceCell::new(); @@ -1178,6 +1180,70 @@ async fn generate_stream_with_callback( (headers, stream) } +#[derive(Serialize, Deserialize, Debug)] +struct MetricFamily { + r#type: String, + data: Vec, +} + +#[derive(Serialize, Deserialize, Debug)] +struct DataPoint { + key: String, + value: f64, +} + +fn parse_text_to_metrics(text: &str) -> HashMap { + let mut metrics = HashMap::new(); + let mut current_metric = String::new(); + let mut current_type = String::new(); + + for line in text.lines() { + if line.is_empty() { + continue; + } + + if line.starts_with("# TYPE ") { + // Extract metric name and type from TYPE declaration + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() >= 4 { + current_metric = parts[2].to_string(); + current_type = parts[3].to_string(); + metrics.insert(current_metric.clone(), MetricFamily { + r#type: current_type.clone(), + data: Vec::new(), + }); + continue; + } + } + + // Parse metric line + if let Some(metric_family) = metrics.get_mut(¤t_metric) { + // Split into name and value parts + let mut parts = line.split_whitespace(); + if let (Some(name_part), Some(value_str)) = (parts.next(), parts.next()) { + if let Ok(value) = value_str.parse::() { + let key = if name_part.contains('{') { + name_part.to_string() + } else if name_part.ends_with("_sum") { + "sum".to_string() + } else if name_part.ends_with("_count") { + "count".to_string() + } else { + "".to_string() + }; + + metric_family.data.push(DataPoint { + key, + value, + }); + } + } + } + } + + metrics +} + /// Prometheus metrics scrape endpoint #[utoipa::path( get, @@ -1199,7 +1265,6 @@ async fn metrics( format: Option>, headers: HeaderMap, ) -> Response { - // Check format query param first, then Accept header let want_json = format .map(|f| f.0.to_lowercase() == "json") .unwrap_or_else(|| { @@ -1213,38 +1278,15 @@ async fn metrics( let prometheus_text = prom_handle.render(); if want_json { - // Parse the Prometheus text format into a structured format - let mut counters = HashMap::new(); - let mut gauges = HashMap::new(); - - // Basic parsing of Prometheus format - for line in prometheus_text.lines() { - if line.starts_with('#') || line.is_empty() { - continue; - } - - if let Some((name, value)) = parse_metric_line(line) { - if name.ends_with("_total") { - counters.insert(name, value); - } else { - gauges.insert(name, value); - } - } - } - - let json_response = json!({ - "counters": counters, - "gauges": gauges, - }); + let metrics = parse_text_to_metrics(&prometheus_text); ( StatusCode::OK, [(header::CONTENT_TYPE, "application/json")], - Json(json_response), + Json(metrics), ) .into_response() } else { - // Return default Prometheus format ( StatusCode::OK, [(header::CONTENT_TYPE, "text/plain")], @@ -1254,19 +1296,6 @@ async fn metrics( } } -/// Helper function to parse a Prometheus metric line -fn parse_metric_line(line: &str) -> Option<(String, f64)> { - let parts: Vec<&str> = line.split_whitespace().collect(); - if parts.len() != 2 { - return None; - } - - let name = parts[0].to_string(); - let value = parts[1].parse().ok()?; - - Some((name, value)) -} - async fn request_logger( request_logger_url: Option, mut rx: mpsc::Receiver<(i64, String, String, String, String, String)>, From f1293e58abf401564512d246d72beae44f1caba3 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Thu, 2 Jan 2025 17:52:21 -0800 Subject: [PATCH 4/9] Fix client tests --- clients/python/lorax/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index 6e2e19554..2b818a6db 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -400,4 +400,4 @@ class ClassifyResponse(BaseModel): class MetricsResponse(BaseModel): # Metrics (can be in JSON or Prometheus [string] format) - metrics: Optional[str | dict] + metrics: Optional[Union[str, dict]] From 75c8d0896d28fee39ae4a3e8b1d290c59ccf1a1c Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Thu, 2 Jan 2025 17:57:06 -0800 Subject: [PATCH 5/9] cleanup --- router/src/server.rs | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index 2417def60..8819af589 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1216,26 +1216,20 @@ fn parse_text_to_metrics(text: &str) -> HashMap { } } - // Parse metric line + // Parse metric line if it belongs to current metric family if let Some(metric_family) = metrics.get_mut(¤t_metric) { - // Split into name and value parts let mut parts = line.split_whitespace(); - if let (Some(name_part), Some(value_str)) = (parts.next(), parts.next()) { + if let (Some(metric_name), Some(value_str)) = (parts.next(), parts.next()) { if let Ok(value) = value_str.parse::() { - let key = if name_part.contains('{') { - name_part.to_string() - } else if name_part.ends_with("_sum") { - "sum".to_string() - } else if name_part.ends_with("_count") { - "count".to_string() - } else { - "".to_string() + let key = match metric_name { + name if name.contains('{') => name.to_string(), + name if name.ends_with("_sum") => "sum".to_string(), + name if name.ends_with("_count") => "count".to_string(), + _ => "".to_string() }; - metric_family.data.push(DataPoint { - key, - value, - }); + // Add the parsed metric data point + metric_family.data.push(DataPoint { key, value }); } } } From c02928dbbbb3ea7eac335e7144e0a643207bd6a8 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Thu, 2 Jan 2025 17:59:27 -0800 Subject: [PATCH 6/9] cleanup --- clients/python/lorax/client.py | 2 -- router/src/server.rs | 4 ++++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index 4aaf1e79a..af8769987 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -561,7 +561,6 @@ def metrics( Returns: MetricsResponse: metrics in the specified format """ - # Simplified header assignment headers = { **(self.headers or {}), "Accept": "application/json" if format == "json" else "text/plain", @@ -975,7 +974,6 @@ async def metrics( """ Get the metrics of the server """ - # Simplified header assignment headers = { **(self.headers or {}), "Accept": "application/json" if format == "json" else "text/plain", diff --git a/router/src/server.rs b/router/src/server.rs index 8819af589..6bbad480f 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1898,6 +1898,10 @@ async fn classify( "lorax_request_inference_duration", inference_time.as_secs_f64() ); + metrics::histogram!( + "lorax_request_classify_output_count", + response.predictions.len() as f64 + ); tracing::debug!("Output: {:?}", response.predictions); tracing::info!("Success"); From 6322de72f859e5335ed56691a5487da66ae7e036 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Thu, 2 Jan 2025 18:10:40 -0800 Subject: [PATCH 7/9] More cleanup --- clients/python/lorax/client.py | 8 +++----- clients/python/lorax/types.py | 5 ----- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index af8769987..b7b7d2a0b 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -17,7 +17,6 @@ ResponseFormat, EmbedResponse, ClassifyResponse, - MetricsResponse, ) from lorax.errors import parse_error @@ -550,7 +549,7 @@ def classify_batch(self, inputs: List[str]) -> List[ClassifyResponse]: def metrics( self, format: Optional[Literal["json", "prometheus"]] = "prometheus" - ) -> MetricsResponse: + ) -> Union[str, dict]: """ Get the metrics of the model @@ -970,7 +969,7 @@ async def classify(self, inputs: str) -> ClassifyResponse: async def metrics( self, format: Literal["json", "prometheus"] = "prometheus" - ) -> MetricsResponse: + ) -> Union[str, dict]: """ Get the metrics of the server """ @@ -990,5 +989,4 @@ async def metrics( raise parse_error( resp.status, payload, resp.headers if LORAX_DEBUG_MODE else None ) - - return MetricsResponse(metrics=payload) + return payload diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index 2b818a6db..c26623e2e 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -396,8 +396,3 @@ class EmbedResponse(BaseModel): class ClassifyResponse(BaseModel): # Classifications entities: Optional[List[dict]] - - -class MetricsResponse(BaseModel): - # Metrics (can be in JSON or Prometheus [string] format) - metrics: Optional[Union[str, dict]] From 907c10a4d1432e93b391d01de04a28385e403dfc Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Thu, 2 Jan 2025 18:19:02 -0800 Subject: [PATCH 8/9] Fix rust warning --- router/src/server.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index 6bbad480f..9f53fb573 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1195,7 +1195,6 @@ struct DataPoint { fn parse_text_to_metrics(text: &str) -> HashMap { let mut metrics = HashMap::new(); let mut current_metric = String::new(); - let mut current_type = String::new(); for line in text.lines() { if line.is_empty() { @@ -1203,13 +1202,13 @@ fn parse_text_to_metrics(text: &str) -> HashMap { } if line.starts_with("# TYPE ") { - // Extract metric name and type from TYPE declaration + // Extract metric name from TYPE declaration + // # TYPE let parts: Vec<&str> = line.split_whitespace().collect(); if parts.len() >= 4 { current_metric = parts[2].to_string(); - current_type = parts[3].to_string(); metrics.insert(current_metric.clone(), MetricFamily { - r#type: current_type.clone(), + r#type: parts[3].to_string(), // Metric type -> histogram, counter, etc data: Vec::new(), }); continue; @@ -1225,7 +1224,7 @@ fn parse_text_to_metrics(text: &str) -> HashMap { name if name.contains('{') => name.to_string(), name if name.ends_with("_sum") => "sum".to_string(), name if name.ends_with("_count") => "count".to_string(), - _ => "".to_string() + _ => "".to_string(), }; // Add the parsed metric data point From 155f72779306d16132beea414422a3beed19bff5 Mon Sep 17 00:00:00 2001 From: Arnav Garg Date: Thu, 2 Jan 2025 18:32:46 -0800 Subject: [PATCH 9/9] Cargo rust formatting --- router/src/server.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/router/src/server.rs b/router/src/server.rs index 9f53fb573..3eb24521d 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -35,6 +35,7 @@ use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use once_cell::sync::OnceCell; use reqwest_middleware::ClientBuilder; use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; +use serde::{Deserialize, Serialize}; use serde_json::Value; use std::cmp; use std::collections::HashMap; @@ -54,8 +55,6 @@ use tower_http::cors::{ use tracing::{info_span, instrument, Instrument}; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; -use serde::{Deserialize, Serialize}; - pub static DEFAULT_ADAPTER_SOURCE: OnceCell = OnceCell::new(); @@ -1207,10 +1206,13 @@ fn parse_text_to_metrics(text: &str) -> HashMap { let parts: Vec<&str> = line.split_whitespace().collect(); if parts.len() >= 4 { current_metric = parts[2].to_string(); - metrics.insert(current_metric.clone(), MetricFamily { - r#type: parts[3].to_string(), // Metric type -> histogram, counter, etc - data: Vec::new(), - }); + metrics.insert( + current_metric.clone(), + MetricFamily { + r#type: parts[3].to_string(), // Metric type -> histogram, counter, etc + data: Vec::new(), + }, + ); continue; } }