diff --git a/source/infrastructure/lib/model/model-construct.ts b/source/infrastructure/lib/model/model-construct.ts index 6d22f93f..99bdcf6c 100644 --- a/source/infrastructure/lib/model/model-construct.ts +++ b/source/infrastructure/lib/model/model-construct.ts @@ -72,7 +72,7 @@ export class ModelConstruct extends NestedStack implements ModelConstructOutputs modelAccount = Aws.ACCOUNT_ID; modelRegion = Aws.REGION; modelIamHelper: IAMHelper; - modelExecutionRole?: iam.Role; + modelExecutionRole?: iam.Role = undefined; modelImageUrlDomain?: string; modelPublicEcrAccount?: string; modelVariantName?: string; @@ -81,56 +81,33 @@ export class ModelConstruct extends NestedStack implements ModelConstructOutputs super(scope, id); this.modelIamHelper = props.sharedConstructOutputs.iamHelper; - // this.defaultEmbeddingModelName = "cohere.embed-english-v3" + // handle embedding model name setup if (props.config.model.embeddingsModels[0].provider === "bedrock") { this.defaultEmbeddingModelName = props.config.model.embeddingsModels[0].name; - } else { - this.modelVariantName = "variantProd"; - this.modelImageUrlDomain = - this.modelRegion === "cn-north-1" || this.modelRegion === "cn-northwest-1" - ? ".amazonaws.com.cn/" - : ".amazonaws.com/"; - - this.modelPublicEcrAccount = - this.modelRegion === "cn-north-1" || this.modelRegion === "cn-northwest-1" - ? "727897471807.dkr.ecr." - : "763104351884.dkr.ecr."; - - - // Create IAM execution role - const executionRole = new iam.Role(this, "intelli-agent-endpoint-execution-role", { - assumedBy: new iam.ServicePrincipal("sagemaker.amazonaws.com"), - managedPolicies: [ - iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonSageMakerFullAccess"), - iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonS3FullAccess"), - iam.ManagedPolicy.fromAwsManagedPolicyName("CloudWatchLogsFullAccess"), - ], - }); - executionRole.addToPolicy(this.modelIamHelper.logStatement); - executionRole.addToPolicy(this.modelIamHelper.s3Statement); - executionRole.addToPolicy(this.modelIamHelper.endpointStatement); - executionRole.addToPolicy(this.modelIamHelper.bedrockStatement); - executionRole.addToPolicy(this.modelIamHelper.stsStatement); - executionRole.addToPolicy(this.modelIamHelper.ecrStatement); - executionRole.addToPolicy(this.modelIamHelper.llmStatement); - - this.modelExecutionRole = executionRole; - - if (props.config.knowledgeBase.enabled && props.config.knowledgeBase.knowledgeBaseType.intelliAgentKb.enabled) { - - // Check if props.config.model.embeddingsModels includes a model with the name 'bce-embedding-and-bge-reranker' - if (props.config.model.embeddingsModels.some(model => model.name === 'bce-embedding-and-bge-reranker')) { - // Create the resource - let embeddingAndRerankerModelResources = this.deployEmbeddingAndRerankerEndpoint(props); - this.defaultEmbeddingModelName = embeddingAndRerankerModelResources.endpoint.endpointName ?? ""; - } + } else if (props.config.model.embeddingsModels[0].provider === "sagemaker") { + // Initialize SageMaker-specific configurations + this.initializeSageMakerConfig(); + + // Set up embedding model if it's the BCE+BGE model + if (props.config.model.embeddingsModels.some(model => model.name === 'bce-embedding-and-bge-reranker')) { + const embeddingAndRerankerModelResources = this.deployEmbeddingAndRerankerEndpoint(props); + this.defaultEmbeddingModelName = embeddingAndRerankerModelResources.endpoint.endpointName ?? ""; + } + } - if (props.config.knowledgeBase.knowledgeBaseType.intelliAgentKb.knowledgeBaseModel.enabled) { - let knowledgeBaseModelResources = this.deployKnowledgeBaseEndpoint(props); - // this.createKnowledgeBaseEndpointScaling(knowledgeBaseModelResources.endpoint); - this.defaultKnowledgeBaseModelName = knowledgeBaseModelResources.endpoint.endpointName ?? ""; - } + // Handle knowledge base setup separately + if (props.config.knowledgeBase.enabled && + props.config.knowledgeBase.knowledgeBaseType.intelliAgentKb.enabled && + props.config.knowledgeBase.knowledgeBaseType.intelliAgentKb.knowledgeBaseModel.enabled) { + + // Initialize SageMaker config if not already done + if (!this.modelExecutionRole) { + this.initializeSageMakerConfig(); } + + // Deploy knowledge base model if enabled + const knowledgeBaseModelResources = this.deployKnowledgeBaseEndpoint(props); + this.defaultKnowledgeBaseModelName = knowledgeBaseModelResources.endpoint.endpointName ?? ""; } if (props.config.chat.useOpenSourceLLM) { @@ -384,5 +361,34 @@ export class ModelConstruct extends NestedStack implements ModelConstructOutputs } + private initializeSageMakerConfig() { + this.modelVariantName = "variantProd"; + + const isChinaRegion = this.modelRegion === "cn-north-1" || this.modelRegion === "cn-northwest-1"; + this.modelImageUrlDomain = isChinaRegion ? ".amazonaws.com.cn/" : ".amazonaws.com/"; + this.modelPublicEcrAccount = isChinaRegion ? "727897471807.dkr.ecr." : "763104351884.dkr.ecr."; + + // Create IAM execution role + const executionRole = new iam.Role(this, "intelli-agent-endpoint-execution-role", { + assumedBy: new iam.ServicePrincipal("sagemaker.amazonaws.com"), + managedPolicies: [ + iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonSageMakerFullAccess"), + iam.ManagedPolicy.fromAwsManagedPolicyName("AmazonS3FullAccess"), + iam.ManagedPolicy.fromAwsManagedPolicyName("CloudWatchLogsFullAccess"), + ], + }); + + // Add required policies + executionRole.addToPolicy(this.modelIamHelper.logStatement); + executionRole.addToPolicy(this.modelIamHelper.s3Statement); + executionRole.addToPolicy(this.modelIamHelper.endpointStatement); + executionRole.addToPolicy(this.modelIamHelper.bedrockStatement); + executionRole.addToPolicy(this.modelIamHelper.stsStatement); + executionRole.addToPolicy(this.modelIamHelper.ecrStatement); + executionRole.addToPolicy(this.modelIamHelper.llmStatement); + + this.modelExecutionRole = executionRole; + } + }