Skip to content

Commit

Permalink
fix: fix knowledge base model deployment when selecting bedrock embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
IcyKallen committed Dec 25, 2024
1 parent 89b2381 commit 4e8f946
Showing 1 changed file with 53 additions and 47 deletions.
100 changes: 53 additions & 47 deletions source/infrastructure/lib/model/model-construct.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ export class ModelConstruct extends NestedStack implements ModelConstructOutputs
modelAccount = Aws.ACCOUNT_ID;
modelRegion = Aws.REGION;
modelIamHelper: IAMHelper;
modelExecutionRole?: iam.Role;
modelExecutionRole?: iam.Role = undefined;
modelImageUrlDomain?: string;
modelPublicEcrAccount?: string;
modelVariantName?: string;
Expand All @@ -81,56 +81,33 @@ export class ModelConstruct extends NestedStack implements ModelConstructOutputs
super(scope, id);
this.modelIamHelper = props.sharedConstructOutputs.iamHelper;

// this.defaultEmbeddingModelName = "cohere.embed-english-v3"
// handle embedding model name setup
if (props.config.model.embeddingsModels[0].provider === "bedrock") {
this.defaultEmbeddingModelName = props.config.model.embeddingsModels[0].name;
} else {
this.modelVariantName = "variantProd";
this.modelImageUrlDomain =
this.modelRegion === "cn-north-1" || this.modelRegion === "cn-northwest-1"
? ".amazonaws.com.cn/"
: ".amazonaws.com/";

this.modelPublicEcrAccount =
this.modelRegion === "cn-north-1" || this.modelRegion === "cn-northwest-1"
? "727897471807.dkr.ecr."
: "763104351884.dkr.ecr.";


// Create IAM execution role
const executionRole = new iam.Role(this, "intelli-agent-endpoint-execution-role", {
assumedBy: new iam.ServicePrincipal("sagemaker.amazonaws.com"),
managedPolicies: [
iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonSageMakerFullAccess"),
iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonS3FullAccess"),
iam.ManagedPolicy.fromAwsManagedPolicyName("CloudWatchLogsFullAccess"),
],
});
executionRole.addToPolicy(this.modelIamHelper.logStatement);
executionRole.addToPolicy(this.modelIamHelper.s3Statement);
executionRole.addToPolicy(this.modelIamHelper.endpointStatement);
executionRole.addToPolicy(this.modelIamHelper.bedrockStatement);
executionRole.addToPolicy(this.modelIamHelper.stsStatement);
executionRole.addToPolicy(this.modelIamHelper.ecrStatement);
executionRole.addToPolicy(this.modelIamHelper.llmStatement);

this.modelExecutionRole = executionRole;

if (props.config.knowledgeBase.enabled && props.config.knowledgeBase.knowledgeBaseType.intelliAgentKb.enabled) {

// Check if props.config.model.embeddingsModels includes a model with the name 'bce-embedding-and-bge-reranker'
if (props.config.model.embeddingsModels.some(model => model.name === 'bce-embedding-and-bge-reranker')) {
// Create the resource
let embeddingAndRerankerModelResources = this.deployEmbeddingAndRerankerEndpoint(props);
this.defaultEmbeddingModelName = embeddingAndRerankerModelResources.endpoint.endpointName ?? "";
}
} else if (props.config.model.embeddingsModels[0].provider === "sagemaker") {
// Initialize SageMaker-specific configurations
this.initializeSageMakerConfig();

// Set up embedding model if it's the BCE+BGE model
if (props.config.model.embeddingsModels.some(model => model.name === 'bce-embedding-and-bge-reranker')) {
const embeddingAndRerankerModelResources = this.deployEmbeddingAndRerankerEndpoint(props);
this.defaultEmbeddingModelName = embeddingAndRerankerModelResources.endpoint.endpointName ?? "";
}
}

if (props.config.knowledgeBase.knowledgeBaseType.intelliAgentKb.knowledgeBaseModel.enabled) {
let knowledgeBaseModelResources = this.deployKnowledgeBaseEndpoint(props);
// this.createKnowledgeBaseEndpointScaling(knowledgeBaseModelResources.endpoint);
this.defaultKnowledgeBaseModelName = knowledgeBaseModelResources.endpoint.endpointName ?? "";
}
// Handle knowledge base setup separately
if (props.config.knowledgeBase.enabled &&
props.config.knowledgeBase.knowledgeBaseType.intelliAgentKb.enabled &&
props.config.knowledgeBase.knowledgeBaseType.intelliAgentKb.knowledgeBaseModel.enabled) {

// Initialize SageMaker config if not already done
if (!this.modelExecutionRole) {
this.initializeSageMakerConfig();
}

// Deploy knowledge base model if enabled
const knowledgeBaseModelResources = this.deployKnowledgeBaseEndpoint(props);
this.defaultKnowledgeBaseModelName = knowledgeBaseModelResources.endpoint.endpointName ?? "";
}

if (props.config.chat.useOpenSourceLLM) {
Expand Down Expand Up @@ -384,5 +361,34 @@ export class ModelConstruct extends NestedStack implements ModelConstructOutputs

}

/**
 * Prepares the settings shared by the SageMaker endpoint deployments of this
 * stack: the production variant name, the region-dependent image URL domain
 * and public ECR account prefix, and the IAM execution role.
 *
 * The method is idempotent: once `modelExecutionRole` has been set, further
 * calls return immediately. Without this guard a second call would try to
 * create another construct with the id "intelli-agent-endpoint-execution-role"
 * in the same scope, which CDK rejects.
 */
private initializeSageMakerConfig(): void {
  // Already initialised by an earlier call; reuse the existing role and settings.
  if (this.modelExecutionRole) {
    return;
  }

  this.modelVariantName = "variantProd";

  // NOTE(review): modelRegion is initialised from Aws.REGION, which may be an
  // unresolved deploy-time token during synthesis; if so, this comparison never
  // matches and the non-China values are always used. Confirm against callers.
  const isChinaRegion = this.modelRegion === "cn-north-1" || this.modelRegion === "cn-northwest-1";
  this.modelImageUrlDomain = isChinaRegion ? ".amazonaws.com.cn/" : ".amazonaws.com/";
  this.modelPublicEcrAccount = isChinaRegion ? "727897471807.dkr.ecr." : "763104351884.dkr.ecr.";

  // Create the IAM execution role that SageMaker assumes.
  const executionRole = new iam.Role(this, "intelli-agent-endpoint-execution-role", {
    assumedBy: new iam.ServicePrincipal("sagemaker.amazonaws.com"),
    managedPolicies: [
      iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonSageMakerFullAccess"),
      iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonS3FullAccess"),
      iam.ManagedPolicy.fromAwsManagedPolicyName("CloudWatchLogsFullAccess"),
    ],
  });

  // Attach the project's shared policy statements on top of the managed policies.
  executionRole.addToPolicy(this.modelIamHelper.logStatement);
  executionRole.addToPolicy(this.modelIamHelper.s3Statement);
  executionRole.addToPolicy(this.modelIamHelper.endpointStatement);
  executionRole.addToPolicy(this.modelIamHelper.bedrockStatement);
  executionRole.addToPolicy(this.modelIamHelper.stsStatement);
  executionRole.addToPolicy(this.modelIamHelper.ecrStatement);
  executionRole.addToPolicy(this.modelIamHelper.llmStatement);

  // Publishing the role last also marks initialisation as complete for the guard above.
  this.modelExecutionRole = executionRole;
}

}

0 comments on commit 4e8f946

Please sign in to comment.