Skip to content

Commit

Permalink
Merge pull request #152 from mdevino/generated-detection-endpoint
Browse files Browse the repository at this point in the history
Implement /detect/generated
  • Loading branch information
gkumbhat authored Aug 6, 2024
2 parents e3feea8 + 7cf5d85 commit 73e05e0
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 11 deletions.
72 changes: 72 additions & 0 deletions docs/api/orchestrator_openapi_0_1_0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,40 @@ paths:
schema:
$ref: '#/components/schemas/Error'

/api/v2/text/detection/generated:
post:
tags:
- Detection tasks
summary: Detection task performing detection on prompt and generated text
operationId: >-
api_v2_detection_text_generated_unary_handler
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/GeneratedTextDetectionRequest'
required: true
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/GeneratedTextDetectionResponse'
'404':
description: Resource Not Found
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/Error'


components:
schemas:
DetectionContentRequest:
Expand Down Expand Up @@ -400,6 +434,44 @@ components:
title: Generation Detection Response
required: ["generated_text", "detections"]

GeneratedTextDetectionRequest:
properties:
prompt:
type: string
title: Prompt
generated_text:
type: string
title: Generated Text
detectors:
type: object
title: Detectors
default: {}
example:
generated-detection-v1-model-en: {}
type: object
required: ["generated_text", "prompt", "detectors"]
title: Generated Text Detection Request
GeneratedTextDetectionResponse:
properties:
detections:
type: array
items:
type: object
title: Detection Object
properties:
detection_type:
type: string
title: Detection Type
detection:
type: string
title: Detection
score:
type: number
title: Score
required: ["detections"]
title: Generated Text Detection Response


ClassifiedGeneratedTextResult:
properties:
generated_text:
Expand Down
41 changes: 41 additions & 0 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,47 @@ pub struct ContextDocsResult {
pub detections: Vec<DetectionResult>,
}

/// The request format expected in the /api/v2/text/detection/generated endpoint.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DetectionOnGeneratedHttpRequest {
    /// The prompt to be sent to the LLM. Must be non-empty (enforced by `validate`).
    #[serde(rename = "prompt")]
    pub prompt: String,

    /// The text generated by the LLM. Must be non-empty (enforced by `validate`).
    #[serde(rename = "generated_text")]
    pub generated_text: String,

    /// The map of detectors to be used, along with their respective parameters, e.g. thresholds.
    /// Keyed by detector id; must contain at least one entry (enforced by `validate`).
    #[serde(rename = "detectors")]
    pub detectors: HashMap<String, DetectorParams>,
}

impl DetectionOnGeneratedHttpRequest {
    /// Upfront validation of the user request.
    ///
    /// Ensures every required field is non-empty before any detector is
    /// contacted. The first missing field — checked in the order `prompt`,
    /// `generated_text`, `detectors` — is reported.
    pub fn validate(&self) -> Result<(), ValidationError> {
        let missing_field = if self.prompt.is_empty() {
            Some("prompt")
        } else if self.generated_text.is_empty() {
            Some("generated_text")
        } else if self.detectors.is_empty() {
            Some("detectors")
        } else {
            None
        };
        match missing_field {
            Some(field) => Err(ValidationError::Required(field.into())),
            None => Ok(()),
        }
    }
}

/// The response format of the /api/v2/text/detection/generated endpoint.
#[derive(Default, Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct DetectionOnGenerationResult {
    /// Detection results produced by the requested detectors, combined into a single list.
    #[serde(rename = "detections")]
    pub detections: Vec<DetectionResult>,
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
33 changes: 30 additions & 3 deletions src/orchestrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ use crate::{
},
config::{GenerationProvider, OrchestratorConfig},
models::{
ContextDocsHttpRequest, DetectorParams, GenerationWithDetectionHttpRequest,
GuardrailsConfig, GuardrailsHttpRequest, GuardrailsTextGenerationParameters,
TextContentDetectionHttpRequest,
ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest, DetectorParams,
GenerationWithDetectionHttpRequest, GuardrailsConfig, GuardrailsHttpRequest,
GuardrailsTextGenerationParameters, TextContentDetectionHttpRequest,
},
};

Expand Down Expand Up @@ -263,6 +263,33 @@ impl ContextDocsDetectionTask {
}
}

/// Task for the /api/v2/text/detection/generated endpoint.
///
/// Built from a validated `DetectionOnGeneratedHttpRequest` via
/// `DetectionOnGenerationTask::new` and handed to the orchestrator.
#[derive(Debug)]
pub struct DetectionOnGenerationTask {
    /// Request unique identifier (used to correlate log lines)
    pub request_id: Uuid,

    /// User prompt to be sent to the LLM
    pub prompt: String,

    /// Text generated by the LLM
    pub generated_text: String,

    /// Detectors configuration, keyed by detector id with per-detector parameters
    pub detectors: HashMap<String, DetectorParams>,
}

impl DetectionOnGenerationTask {
pub fn new(request_id: Uuid, request: DetectionOnGeneratedHttpRequest) -> Self {
Self {
request_id,
prompt: request.prompt,
generated_text: request.generated_text,
detectors: request.detectors,
}
}
}

#[allow(dead_code)]
#[derive(Debug)]
pub struct StreamingClassificationWithGenTask {
Expand Down
71 changes: 65 additions & 6 deletions src/orchestrator/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ use tracing::{debug, error, info, instrument};

use super::{
apply_masks, get_chunker_ids, Chunk, ClassificationWithGenTask, Context,
ContextDocsDetectionTask, Error, GenerationWithDetectionTask, Orchestrator,
TextContentDetectionTask,
ContextDocsDetectionTask, DetectionOnGenerationTask, Error, GenerationWithDetectionTask,
Orchestrator, TextContentDetectionTask,
};
use crate::{
clients::detector::{
ContentAnalysisRequest, ContextDocsDetectionRequest, ContextType,
GenerationDetectionRequest,
},
models::{
ClassifiedGeneratedTextResult, ContextDocsResult, DetectionResult, DetectorParams,
GenerationWithDetectionResult, GuardrailsTextGenerationParameters, InputWarning,
InputWarningReason, TextContentDetectionResult, TextGenTokenClassificationResults,
TokenClassificationResult,
ClassifiedGeneratedTextResult, ContextDocsResult, DetectionOnGenerationResult,
DetectionResult, DetectorParams, GenerationWithDetectionResult,
GuardrailsTextGenerationParameters, InputWarning, InputWarningReason,
TextContentDetectionResult, TextGenTokenClassificationResults, TokenClassificationResult,
},
orchestrator::UNSUITABLE_INPUT_MESSAGE,
pb::caikit::runtime::chunkers,
Expand Down Expand Up @@ -313,6 +313,65 @@ impl Orchestrator {
}
}
}

/// Handles detections on generated text (without performing generation).
///
/// Runs every detector in `task.detectors` concurrently against the
/// (prompt, generated_text) pair and returns their combined results.
/// Fails with the first detector error encountered.
pub async fn handle_generated_text_detection(
    &self,
    task: DetectionOnGenerationTask,
) -> Result<DetectionOnGenerationResult, Error> {
    info!(
        request_id = ?task.request_id,
        detectors = ?task.detectors,
        "handling detection on generated content task"
    );
    let ctx = self.ctx.clone();
    // NOTE(review): the `async move` block below uses only `task.detectors`,
    // `task.prompt` and `task.generated_text`, so (edition-2021 precise
    // capture, presumably) `task.request_id` remains usable for the error
    // logging after the `.await` — confirm if the edition ever changes.
    let task_handle = tokio::spawn(async move {
        // call detection
        // Fan out one future per configured detector; try_join_all runs them
        // concurrently and short-circuits on the first error.
        let detections = try_join_all(
            task.detectors
                .iter()
                .map(|(detector_id, detector_params)| {
                    // Clone per-call inputs so each future owns its data.
                    let ctx = ctx.clone();
                    let detector_id = detector_id.clone();
                    let detector_params = detector_params.clone();
                    let prompt = task.prompt.clone();
                    let generated_text = task.generated_text.clone();
                    async {
                        detect_for_generation(
                            ctx,
                            detector_id,
                            detector_params,
                            prompt,
                            generated_text,
                        )
                        .await
                    }
                })
                .collect::<Vec<_>>(),
        )
        .await?
        .into_iter()
        // Each detector returns a list of results; merge them into one.
        .flatten()
        .collect::<Vec<_>>();

        Ok(DetectionOnGenerationResult { detections })
    });
    match task_handle.await {
        // Task completed successfully
        Ok(Ok(result)) => Ok(result),
        // Task failed, return error propagated from child task that failed
        Ok(Err(error)) => {
            error!(request_id = ?task.request_id, %error, "detection on generated content task failed");
            Err(error)
        }
        // Task cancelled or panicked
        Err(error) => {
            let error = error.into();
            error!(request_id = ?task.request_id, %error, "detection on generated content task failed");
            Err(error)
        }
    }
}
}

/// Handles input detection task.
Expand Down
29 changes: 27 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ use crate::{
config::OrchestratorConfig,
models::{self},
orchestrator::{
self, ClassificationWithGenTask, ContextDocsDetectionTask, GenerationWithDetectionTask,
Orchestrator, StreamingClassificationWithGenTask, TextContentDetectionTask,
self, ClassificationWithGenTask, ContextDocsDetectionTask, DetectionOnGenerationTask,
GenerationWithDetectionTask, Orchestrator, StreamingClassificationWithGenTask,
TextContentDetectionTask,
},
};

Expand Down Expand Up @@ -163,6 +164,10 @@ pub async fn run(
&format!("{}/detection/context", TEXT_API_PREFIX),
post(detect_context_documents),
)
.route(
&format!("{}/detection/generated", TEXT_API_PREFIX),
post(detect_generated),
)
.with_state(shared_state);

// (2c) Generate main guardrails server handle based on whether TLS is needed
Expand Down Expand Up @@ -377,6 +382,26 @@ async fn detect_context_documents(
}
}

async fn detect_generated(
State(state): State<Arc<ServerState>>,
WithRejection(Json(request), _): WithRejection<
Json<models::DetectionOnGeneratedHttpRequest>,
Error,
>,
) -> Result<impl IntoResponse, Error> {
let request_id = Uuid::new_v4();
request.validate()?;
let task = DetectionOnGenerationTask::new(request_id, request);
match state
.orchestrator
.handle_generated_text_detection(task)
.await
{
Ok(response) => Ok(Json(response).into_response()),
Err(error) => Err(error.into()),
}
}

/// Shutdown signal handler
async fn shutdown_signal() {
let ctrl_c = async {
Expand Down

0 comments on commit 73e05e0

Please sign in to comment.