Skip to content

Commit

Permalink
Merge pull request #114 from cloudera/mob/main
Browse files Browse the repository at this point in the history
Store and Use Reranking Models
  • Loading branch information
ewilliams-cloudera authored Jan 27, 2025
2 parents eb99b62 + 6a8cd16 commit 0affac5
Show file tree
Hide file tree
Showing 28 changed files with 517 additions and 62 deletions.
5 changes: 1 addition & 4 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,4 @@ CAII_DOMAIN=
USE_SYSTEM_UV=true

# set this to true to enable enhanced pdf processing with docling
USE_ENHANCED_PDF_PROCESSING=false

# enable two-stage retrieval
ENABLE_TWO_STAGE_RETRIEVAL=false
USE_ENHANCED_PDF_PROCESSING=false
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
!.idea/google-java-format.xml
chat_store.json
databases/
**/docling-output.txt
**/docling-output.txt
**/.DS_Store
3 changes: 3 additions & 0 deletions backend/src/main/java/com/cloudera/cai/rag/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ public record Session(
String updatedById,
Instant lastInteractionTime,
String inferenceModel,
String rerankModel,
Integer responseChunks) {

public static Session fromCreateRequest(CreateSession input, String username) {
Expand All @@ -128,6 +129,7 @@ public static Session fromCreateRequest(CreateSession input, String username) {
username,
null,
input.inferenceModel(),
input.rerankModel(),
input.responseChunks());
}
}
Expand All @@ -136,5 +138,6 @@ public record CreateSession(
String name,
@Singular List<Long> dataSourceIds,
String inferenceModel,
String rerankModel,
Integer responseChunks) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ public Long create(Types.Session input) {
handle -> {
var sql =
"""
INSERT INTO CHAT_SESSION (name, created_by_id, updated_by_id, inference_model, response_chunks)
VALUES (:name, :createdById, :updatedById, :inferenceModel, :responseChunks)
INSERT INTO CHAT_SESSION (name, created_by_id, updated_by_id, inference_model, rerank_model, response_chunks)
VALUES (:name, :createdById, :updatedById, :inferenceModel, :rerankModel, :responseChunks)
""";
Long id = insertSession(input, handle, sql);
insertSessionDataSources(handle, id, input.dataSourceIds());
Expand Down Expand Up @@ -125,6 +125,7 @@ private Stream<Types.Session.SessionBuilder> querySessions(Query query) {
.name(rowView.getColumn("name", String.class))
.inferenceModel(rowView.getColumn("inference_model", String.class))
.responseChunks(rowView.getColumn("response_chunks", Integer.class))
.rerankModel(rowView.getColumn("rerank_model", String.class))
.createdById(rowView.getColumn("created_by_id", String.class))
.timeCreated(rowView.getColumn("time_created", Instant.class))
.updatedById(rowView.getColumn("updated_by_id", String.class))
Expand Down Expand Up @@ -170,7 +171,8 @@ public void update(Types.Session input) {
var sql =
"""
UPDATE CHAT_SESSION
SET name = :name, updated_by_id = :updatedById, inference_model = :inferenceModel, response_chunks = :responseChunks, time_updated = :timeUpdated
SET name = :name, updated_by_id = :updatedById, inference_model = :inferenceModel,
response_chunks = :responseChunks, time_updated = :timeUpdated, rerank_model = :rerankModel
WHERE id = :id
""";
handle.createUpdate(sql).bindMethods(updatedInput).execute();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
*
* Applicable Open Source License: Apache 2.0
*
* NOTE: Cloudera open source products are modular software products
* made up of hundreds of individual components, each of which was
* individually copyrighted. Each Cloudera open source product is a
* collective work under U.S. Copyright Law. Your license to use the
* collective work is as provided in your written agreement with
* Cloudera. Used apart from the collective work, this file is
* licensed for your use pursuant to the open source license
* identified above.
*
* This code is provided to you pursuant a written agreement with
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
* this code. If you do not have a written agreement with Cloudera nor
* with an authorized and properly licensed third party, you do not
* have any rights to access nor to use this code.
*
* Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
*/

SET MODE MYSQL;

BEGIN;

ALTER TABLE CHAT_SESSION DROP COLUMN rerank_model;

COMMIT;
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
*
* Applicable Open Source License: Apache 2.0
*
* NOTE: Cloudera open source products are modular software products
* made up of hundreds of individual components, each of which was
* individually copyrighted. Each Cloudera open source product is a
* collective work under U.S. Copyright Law. Your license to use the
* collective work is as provided in your written agreement with
* Cloudera. Used apart from the collective work, this file is
* licensed for your use pursuant to the open source license
* identified above.
*
* This code is provided to you pursuant a written agreement with
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
* this code. If you do not have a written agreement with Cloudera nor
* with an authorized and properly licensed third party, you do not
* have any rights to access nor to use this code.
*
* Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
*/

SET MODE MYSQL;

BEGIN;

ALTER TABLE CHAT_SESSION ADD COLUMN rerank_model VARCHAR(255);


COMMIT;
4 changes: 3 additions & 1 deletion backend/src/main/resources/migrations/migrations.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@
15_add_summarization_model.down.sql
15_add_summarization_model.up.sql
16_add_doc_status.down.sql
16_add_doc_status.up.sql
16_add_doc_status.up.sql
17_add_chat_rerank_model.down.sql
17_add_chat_rerank_model.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
*
* Applicable Open Source License: Apache 2.0
*
* NOTE: Cloudera open source products are modular software products
* made up of hundreds of individual components, each of which was
* individually copyrighted. Each Cloudera open source product is a
* collective work under U.S. Copyright Law. Your license to use the
* collective work is as provided in your written agreement with
* Cloudera. Used apart from the collective work, this file is
* licensed for your use pursuant to the open source license
* identified above.
*
* This code is provided to you pursuant a written agreement with
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
* this code. If you do not have a written agreement with Cloudera nor
* with an authorized and properly licensed third party, you do not
* have any rights to access nor to use this code.
*
* Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
*/

BEGIN;

ALTER TABLE CHAT_SESSION DROP COLUMN rerank_model;

COMMIT;
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
*
* Applicable Open Source License: Apache 2.0
*
* NOTE: Cloudera open source products are modular software products
* made up of hundreds of individual components, each of which was
* individually copyrighted. Each Cloudera open source product is a
* collective work under U.S. Copyright Law. Your license to use the
* collective work is as provided in your written agreement with
* Cloudera. Used apart from the collective work, this file is
* licensed for your use pursuant to the open source license
* identified above.
*
* This code is provided to you pursuant a written agreement with
* (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
* this code. If you do not have a written agreement with Cloudera nor
* with an authorized and properly licensed third party, you do not
* have any rights to access nor to use this code.
*
* Absent a written agreement with Cloudera, Inc. (“Cloudera”) to the
* contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
* KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
* WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
* IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
* AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
* ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
* OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
* CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
* RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
* BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
* DATA.
*/

BEGIN;

ALTER TABLE CHAT_SESSION ADD COLUMN rerank_model VARCHAR(255);


COMMIT;
15 changes: 13 additions & 2 deletions backend/src/test/java/com/cloudera/cai/rag/TestData.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,22 @@
public class TestData {
public static Types.Session createTestSessionInstance(String sessionName) {
return new Types.Session(
null, sessionName, List.of(1L, 2L, 3L), null, null, null, null, null, "test-model", 3);
null,
sessionName,
List.of(1L, 2L, 3L),
null,
null,
null,
null,
null,
"test-model",
"test-rerank-model",
3);
}

public static Types.CreateSession createSessionInstance(String sessionName) {
return new Types.CreateSession(sessionName, List.of(1L, 2L, 3L), "test-model", 3);
return new Types.CreateSession(
sessionName, List.of(1L, 2L, 3L), "test-model", "test-rerank-model", 3);
}

public static Types.RagDataSource createTestDataSourceInstance(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ void create() throws JsonProcessingException {
assertThat(result.id()).isNotNull();
assertThat(result.name()).isEqualTo(sessionName);
assertThat(result.inferenceModel()).isEqualTo(input.inferenceModel());
assertThat(result.rerankModel()).isEqualTo(input.rerankModel());
assertThat(result.responseChunks()).isEqualTo(input.responseChunks());
assertThat(result.dataSourceIds()).containsExactlyInAnyOrder(1L, 2L, 3L);
assertThat(result.timeCreated()).isNotNull();
Expand Down Expand Up @@ -95,6 +96,7 @@ void update() throws JsonProcessingException {
var updatedResponseChunks = 1;
var updatedInferenceModel = "new-model-name";
var updatedName = "new-name";
var updatedRerankModel = "new-rerank-model";

request = new MockHttpServletRequest();
request.setCookies(
Expand All @@ -106,12 +108,14 @@ void update() throws JsonProcessingException {
insertedSession
.withInferenceModel(updatedInferenceModel)
.withResponseChunks(updatedResponseChunks)
.withRerankModel(updatedRerankModel)
.withName(updatedName),
request);

assertThat(updatedSession.id()).isNotNull();
assertThat(updatedSession.name()).isEqualTo(updatedName);
assertThat(updatedSession.inferenceModel()).isEqualTo(updatedInferenceModel);
assertThat(updatedSession.rerankModel()).isEqualTo(updatedRerankModel);
assertThat(updatedSession.responseChunks()).isEqualTo(updatedResponseChunks);
assertThat(updatedSession.dataSourceIds()).containsExactlyInAnyOrder(1L, 2L, 3L);
assertThat(updatedSession.timeCreated()).isNotNull();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*******************************************************************************
/*
* CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
* (C) Cloudera, Inc. 2024
* All rights reserved.
Expand Down
14 changes: 14 additions & 0 deletions llm-service/app/routers/index/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
get_model_source,
test_embedding_model,
test_llm_model,
get_available_rerank_models,
test_reranking_model,
)

router = APIRouter(prefix="/models", tags=["Models"])
Expand All @@ -65,6 +67,12 @@ def get_llm_embedding_models() -> List[ModelResponse]:
return get_available_embedding_models()


@router.get("/reranking", summary="Get reranking models.")
@exceptions.propagates
def get_reranking_models() -> List[ModelResponse]:
return get_available_rerank_models()


@router.get("/model_source", summary="Model source enabled - Bedrock or CAII")
@exceptions.propagates
def get_model() -> ModelSource:
Expand All @@ -81,3 +89,9 @@ def llm_model_test(model_name: str) -> Literal["ok"]:
@exceptions.propagates
def embedding_model_test(model_name: str) -> str:
return test_embedding_model(model_name)


@router.get("/reranking/{model_name}/test", summary="Test Reranking model.")
@exceptions.propagates
def reranking_model_test(model_name: str) -> str:
return test_reranking_model(model_name)
Loading

0 comments on commit 0affac5

Please sign in to comment.