diff --git a/jest.config.js b/jest.config.js index 829b8cb4..d130bfea 100644 --- a/jest.config.js +++ b/jest.config.js @@ -5,7 +5,7 @@ module.exports = { moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'], testTimeout: 60000, // because user will cause other test fail, but we still have user spec - coveragePathIgnorePatterns: ['/milvus/User.ts'], + coveragePathIgnorePatterns: ['dist'], testPathIgnorePatterns: ['cloud.spec.ts', 'serverless.spec.ts'], // add this line testEnvironmentOptions: { NODE_ENV: 'production', diff --git a/milvus/MilvusClient.ts b/milvus/MilvusClient.ts index aaafa1e2..5903ab43 100644 --- a/milvus/MilvusClient.ts +++ b/milvus/MilvusClient.ts @@ -60,8 +60,14 @@ export class MilvusClient extends GRPCClient { logger.debug( `new client initialized, version: ${MilvusClient.sdkInfo.version} ` ); - // connect(); - this.connect(MilvusClient.sdkInfo.version); + + // If the configOrAddress is a string (i.e., the server's address), or if the configOrAddress object does not have the __SKIP_CONNECT__ property set to true, then establish a connection to the Milvus server using the current SDK version. + if ( + typeof configOrAddress === 'string' || + !(configOrAddress as ClientConfig).__SKIP_CONNECT__ + ) { + this.connect(MilvusClient.sdkInfo.version); + } } // High level API: align with pymilvus diff --git a/milvus/const/client.ts b/milvus/const/client.ts index f6e39bac..56073f7a 100644 --- a/milvus/const/client.ts +++ b/milvus/const/client.ts @@ -6,9 +6,10 @@ export enum METADATA { export enum CONNECT_STATUS { NOT_CONNECTED, - CONNECTING, - CONNECTED, + CONNECTING = 0, // GRPC channel state connecting + CONNECTED = 1, // GRPC channel state ready UNIMPLEMENTED, + SHUTDOWN = 5, // GRPC channel state shutdown } export enum TLS_MODE { diff --git a/milvus/const/defaults.ts b/milvus/const/defaults.ts index fd9dbf1f..65b36348 100644 --- a/milvus/const/defaults.ts +++ b/milvus/const/defaults.ts @@ -16,3 +16,6 @@ export const DEFAULT_DYNAMIC_FIELD = '$meta'; export const DEFAULT_COUNT_QUERY_STRING = 'count(*)'; export const DEFAULT_HTTP_TIMEOUT = 60000; // 60s export const DEFAULT_HTTP_ENDPOINT_VERSION = 'v1'; // api version, default v1 + +export const DEFAULT_POOL_MAX = 10; // default max pool client number +export const DEFAULT_POOL_MIN = 2; // default min pool client number diff --git a/milvus/grpc/BaseClient.ts b/milvus/grpc/BaseClient.ts index e4bcd65b..2a3b3fe4 100644 --- a/milvus/grpc/BaseClient.ts +++ b/milvus/grpc/BaseClient.ts @@ -1,7 +1,14 @@ import path from 'path'; import crypto from 'crypto'; import protobuf, { Root, Type } from 'protobufjs'; -import { Client, ChannelOptions } from '@grpc/grpc-js'; +import { readFileSync } from 'fs'; +import { + Client, + ChannelOptions, + credentials, + ChannelCredentials, +} from '@grpc/grpc-js'; +import { Pool } from 'generic-pool'; import { ERROR_REASONS, ClientConfig, @@ -26,18 +33,36 @@ const schemaProtoPath = path.resolve( * Base gRPC client, setup all configuration here */ export class BaseClient { + // channel pool + public channelPool!: Pool; // Client ID - clientId: string = `${crypto.randomUUID()}`; + public clientId: string = `${crypto.randomUUID()}`; // flags to indicate that if the connection is established and its state - connectStatus = CONNECT_STATUS.NOT_CONNECTED; - connectPromise = Promise.resolve(); - // metadata - protected metadata: Map = new Map(); - // The path to the Milvus protobuf file. - protected protoFilePath = { + public connectStatus = CONNECT_STATUS.NOT_CONNECTED; + // connection promise + public connectPromise = Promise.resolve(); + // TLS mode, by default it is disabled + public readonly tlsMode: TLS_MODE = TLS_MODE.DISABLED; + // The client configuration. + public readonly config: ClientConfig; + // grpc options + public readonly channelOptions: ChannelOptions; + // server info + public serverInfo: ServerInfo = {}; + // // The gRPC client instance. + // public client!: Promise; + // The timeout for connecting to the Milvus service. + public timeout: number = DEFAULT_CONNECT_TIMEOUT; + // The path to the Milvus protobuf file, user can define it from clientConfig + public protoFilePath = { milvus: milvusProtoPath, schema: schemaProtoPath, }; + + // ChannelCredentials object used for authenticating the client on the gRPC channel. + protected creds!: ChannelCredentials; + // global metadata, send each grpc request with it + protected metadata: Map = new Map(); // The protobuf schema. protected schemaProto: Root; // The Milvus protobuf. @@ -46,7 +71,6 @@ export class BaseClient { protected collectionSchemaType: Type; // The milvus field schema Type protected fieldSchemaType: Type; - // milvus proto protected readonly protoInternalPath = { serviceName: 'milvus.proto.milvus.MilvusService', @@ -54,19 +78,6 @@ export class BaseClient { fieldSchema: 'milvus.proto.schema.FieldSchema', }; - // TLS mode, by default it is disabled - public readonly tlsMode: TLS_MODE = TLS_MODE.DISABLED; - // The client configuration. - public readonly config: ClientConfig; - // grpc options - public readonly channelOptions: ChannelOptions; - // server info - public serverInfo: ServerInfo = {}; - // The gRPC client instance. - public client: Client | undefined; - // The timeout for connecting to the Milvus service. - public timeout: number = DEFAULT_CONNECT_TIMEOUT; - /** * Sets up the configuration object for the gRPC client. * @@ -163,6 +174,51 @@ export class BaseClient { this.config.tls.serverName; } + // Switch based on the TLS mode + switch (this.tlsMode) { + case TLS_MODE.ONE_WAY: + // Create SSL credentials with empty parameters for one-way authentication + this.creds = credentials.createSsl(); + break; + case TLS_MODE.TWO_WAY: + // Extract paths for root certificate, private key, certificate chain, and verify options from the client configuration + const { rootCertPath, privateKeyPath, certChainPath, verifyOptions } = + this.config.tls!; + + // Initialize buffers for root certificate, private key, and certificate chain + let rootCertBuff: Buffer | null = null; + let privateKeyBuff: Buffer | null = null; + let certChainBuff: Buffer | null = null; + + // Read root certificate file if path is provided + if (rootCertPath) { + rootCertBuff = readFileSync(rootCertPath); + } + + // Read private key file if path is provided + if (privateKeyPath) { + privateKeyBuff = readFileSync(privateKeyPath); + } + + // Read certificate chain file if path is provided + if (certChainPath) { + certChainBuff = readFileSync(certChainPath); + } + + // Create SSL credentials with the read files and verify options for two-way authentication + this.creds = credentials.createSsl( + rootCertBuff, + privateKeyBuff, + certChainBuff, + verifyOptions + ); + break; + default: + // Create insecure credentials if no TLS mode is specified + this.creds = credentials.createInsecure(); + break; + } + // Set up the timeout for connecting to the Milvus service. this.timeout = typeof config.timeout === 'string' diff --git a/milvus/grpc/Collection.ts b/milvus/grpc/Collection.ts index f27ad80d..30068471 100644 --- a/milvus/grpc/Collection.ts +++ b/milvus/grpc/Collection.ts @@ -153,7 +153,7 @@ export class Collection extends Database { // Call the promisify function to create the collection. const createPromise = await promisify( - this.client, + this.channelPool, 'CreateCollection', { ...data, @@ -201,7 +201,7 @@ export class Collection extends Database { // avoid to call describe collection, because it has cache const res = await promisify( - this.client, + this.channelPool, 'DescribeCollection', data, data.timeout || this.timeout @@ -242,7 +242,7 @@ export class Collection extends Database { data?: ShowCollectionsReq ): Promise { const promise = await promisify( - this.client, + this.channelPool, 'ShowCollections', { type: data ? data.type : ShowCollectionsType.All, @@ -290,7 +290,7 @@ export class Collection extends Database { async alterCollection(data: AlterCollectionReq): Promise { checkCollectionName(data); const promise = await promisify( - this.client, + this.channelPool, 'AlterCollection', { collection_name: data.collection_name, @@ -346,7 +346,7 @@ export class Collection extends Database { // get new data const promise = await promisify( - this.client, + this.channelPool, 'DescribeCollection', data, data.timeout || this.timeout @@ -391,7 +391,7 @@ export class Collection extends Database { checkCollectionName(data); const promise = await promisify( - this.client, + this.channelPool, 'GetCollectionStatistics', data, data.timeout || this.timeout @@ -432,7 +432,7 @@ export class Collection extends Database { checkCollectionName(data); const promise = await promisify( - this.client, + this.channelPool, 'LoadCollection', data, data.timeout || this.timeout @@ -470,7 +470,7 @@ export class Collection extends Database { checkCollectionName(data); const promise = await promisify( - this.client, + this.channelPool, 'LoadCollection', data, data.timeout || this.timeout @@ -529,7 +529,7 @@ export class Collection extends Database { checkCollectionName(data); const promise = await promisify( - this.client, + this.channelPool, 'ReleaseCollection', data, data.timeout || this.timeout @@ -564,7 +564,7 @@ export class Collection extends Database { */ async renameCollection(data: RenameCollectionReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'RenameCollection', { oldName: data.collection_name, @@ -602,7 +602,7 @@ export class Collection extends Database { checkCollectionName(data); const promise = await promisify( - this.client, + this.channelPool, 'DropCollection', data, data.timeout || this.timeout @@ -649,7 +649,7 @@ export class Collection extends Database { throw new Error(ERROR_REASONS.ALIAS_NAME_IS_REQUIRED); } const promise = await promisify( - this.client, + this.channelPool, 'CreateAlias', data, data.timeout || this.timeout @@ -688,7 +688,7 @@ export class Collection extends Database { throw new Error(ERROR_REASONS.ALIAS_NAME_IS_REQUIRED); } const promise = await promisify( - this.client, + this.channelPool, 'DropAlias', data, data.timeout || this.timeout @@ -728,7 +728,7 @@ export class Collection extends Database { throw new Error(ERROR_REASONS.ALIAS_NAME_IS_REQUIRED); } const promise = await promisify( - this.client, + this.channelPool, 'AlterAlias', data, data.timeout || this.timeout @@ -763,7 +763,7 @@ export class Collection extends Database { checkCollectionName(data); const collectionInfo = await this.describeCollection(data); const res = await promisify( - this.client, + this.channelPool, 'ManualCompaction', { collectionID: collectionInfo.collectionID, @@ -803,7 +803,7 @@ export class Collection extends Database { throw new Error(ERROR_REASONS.COMPACTION_ID_IS_REQUIRED); } const res = await promisify( - this.client, + this.channelPool, 'GetCompactionState', data, data.timeout || this.timeout @@ -841,7 +841,7 @@ export class Collection extends Database { throw new Error(ERROR_REASONS.COMPACTION_ID_IS_REQUIRED); } const res = await promisify( - this.client, + this.channelPool, 'GetCompactionStateWithPlans', data, data.timeout || this.timeout @@ -894,7 +894,7 @@ export class Collection extends Database { throw new Error(ERROR_REASONS.COLLECTION_ID_IS_REQUIRED); } const res = await promisify( - this.client, + this.channelPool, 'GetReplicas', data, data.timeout || this.timeout @@ -935,7 +935,7 @@ export class Collection extends Database { throw new Error(ERROR_REASONS.COLLECTION_NAME_IS_REQUIRED); } const res = await promisify( - this.client, + this.channelPool, 'GetLoadingProgress', data, data.timeout || this.timeout @@ -973,7 +973,7 @@ export class Collection extends Database { throw new Error(ERROR_REASONS.COLLECTION_NAME_IS_REQUIRED); } const res = await promisify( - this.client, + this.channelPool, 'GetLoadState', data, data.timeout || this.timeout diff --git a/milvus/grpc/Data.ts b/milvus/grpc/Data.ts index 5a94a3c5..cb6cd775 100644 --- a/milvus/grpc/Data.ts +++ b/milvus/grpc/Data.ts @@ -259,7 +259,7 @@ export class Data extends Collection { const timeout = typeof data.timeout === 'undefined' ? 0 : data.timeout; // execute Insert const promise = await promisify( - this.client, + this.channelPool, upsert ? 'Upsert' : 'Insert', params, timeout @@ -310,7 +310,7 @@ export class Data extends Collection { data.expr = data.filter || data.expr; const promise = await promisify( - this.client, + this.channelPool, 'Delete', data, data.timeout || this.timeout @@ -487,7 +487,7 @@ export class Data extends Collection { ).finish(); const promise: SearchRes = await promisify( - this.client, + this.channelPool, 'Search', { collection_name: data.collection_name, @@ -627,7 +627,7 @@ export class Data extends Collection { throw new Error(ERROR_REASONS.COLLECTION_NAME_IS_REQUIRED); } const res = await promisify( - this.client, + this.channelPool, 'Flush', data, data.timeout || this.timeout @@ -669,7 +669,7 @@ export class Data extends Collection { } // copy flushed collection names const res = await promisify( - this.client, + this.channelPool, 'Flush', data, data.timeout || this.timeout @@ -740,7 +740,7 @@ export class Data extends Collection { // Execute the query and get the results const promise: QueryRes = await promisify( - this.client, + this.channelPool, 'Query', { ...data, @@ -871,7 +871,7 @@ export class Data extends Collection { throw new Error(ERROR_REASONS.GET_METRIC_CHECK_PARAMS); } const res: GetMetricsResponse = await promisify( - this.client, + this.channelPool, 'GetMetrics', { request: JSON.stringify(data.request), @@ -914,7 +914,7 @@ export class Data extends Collection { throw new Error(ERROR_REASONS.GET_FLUSH_STATE_CHECK_PARAMS); } const res = await promisify( - this.client, + this.channelPool, 'GetFlushState', data, data.timeout || this.timeout @@ -954,7 +954,7 @@ export class Data extends Collection { throw new Error(ERROR_REASONS.LOAD_BALANCE_CHECK_PARAMS); } const res = await promisify( - this.client, + this.channelPool, 'LoadBalance', data, data.timeout || this.timeout @@ -994,7 +994,7 @@ export class Data extends Collection { throw new Error(ERROR_REASONS.COLLECTION_NAME_IS_REQUIRED); } const res = await promisify( - this.client, + this.channelPool, 'GetQuerySegmentInfo', data, data.timeout || this.timeout @@ -1034,7 +1034,7 @@ export class Data extends Collection { throw new Error(ERROR_REASONS.COLLECTION_NAME_IS_REQUIRED); } const res = await promisify( - this.client, + this.channelPool, 'GetPersistentSegmentInfo', data, data.timeout || this.timeout @@ -1078,7 +1078,7 @@ export class Data extends Collection { throw new Error(ERROR_REASONS.IMPORT_FILE_CHECK); } const res = await promisify( - this.client, + this.channelPool, 'Import', { ...data, @@ -1126,7 +1126,7 @@ export class Data extends Collection { throw new Error(ERROR_REASONS.COLLECTION_NAME_IS_REQUIRED); } const res = await promisify( - this.client, + this.channelPool, 'ListImportTasks', { ...data, @@ -1171,7 +1171,7 @@ export class Data extends Collection { // } // const res = await promisify( - // this.client, + // this.channelPool, // 'ListIndexedSegment', // data, // data.timeout || this.timeout @@ -1213,7 +1213,7 @@ export class Data extends Collection { // } // const res = await promisify( - // this.client, + // this.channelPool, // 'DescribeSegmentIndexData', // data, // data.timeout || this.timeout diff --git a/milvus/grpc/Database.ts b/milvus/grpc/Database.ts index 3322c33f..d3927c1d 100644 --- a/milvus/grpc/Database.ts +++ b/milvus/grpc/Database.ts @@ -37,7 +37,7 @@ export class Database extends BaseClient { }); const promise = await promisify( - this.client, + this.channelPool, 'CreateDatabase', data, data.timeout || this.timeout @@ -74,7 +74,7 @@ export class Database extends BaseClient { }); const promise = await promisify( - this.client, + this.channelPool, 'ListDatabases', {}, data?.timeout || this.timeout @@ -110,7 +110,7 @@ export class Database extends BaseClient { }); const promise = await promisify( - this.client, + this.channelPool, 'DropDatabase', data, data.timeout || this.timeout diff --git a/milvus/grpc/GrpcClient.ts b/milvus/grpc/GrpcClient.ts index eccf0545..287dc85a 100644 --- a/milvus/grpc/GrpcClient.ts +++ b/milvus/grpc/GrpcClient.ts @@ -1,6 +1,11 @@ -import { readFileSync } from 'fs'; -import { credentials, Metadata, ChannelCredentials } from '@grpc/grpc-js'; +import { + Metadata, + ServiceClientConstructor, + ChannelOptions, + Client, +} from '@grpc/grpc-js'; import dayjs from 'dayjs'; +import { createPool } from 'generic-pool'; import { GetVersionResponse, CheckHealthResponse, @@ -17,7 +22,9 @@ import { METADATA, logger, CONNECT_STATUS, - TLS_MODE, + ClientConfig, + DEFAULT_POOL_MAX, + DEFAULT_POOL_MIN, } from '../'; import { User } from './User'; @@ -25,14 +32,39 @@ import { User } from './User'; * A client for interacting with the Milvus server via gRPC. */ export class GRPCClient extends User { - // create a grpc service client(connect) - connect(sdkVersion: string) { - // get Milvus GRPC service + /** + * Creates a new instance of MilvusClient. + * @param configOrAddress The Milvus server's address or client configuration object. + * @param ssl Whether to use SSL or not. + * @param username The username for authentication. + * @param password The password for authentication. + * @param channelOptions Additional channel options for gRPC. + */ + constructor( + configOrAddress: ClientConfig | string, + ssl?: boolean, + username?: string, + password?: string, + channelOptions?: ChannelOptions + ) { + // setup the configuration + super(configOrAddress, ssl, username, password, channelOptions); + + // Get the gRPC service for Milvus const MilvusService = getGRPCService({ protoPath: this.protoFilePath.milvus, serviceName: this.protoInternalPath.serviceName, // the name of the Milvus service }); + // setup auth if necessary + const auth = getAuthString(this.config); + if (auth.length > 0) { + this.metadata.set(METADATA.AUTH, auth); + } + + // setup database + this.metadata.set(METADATA.DATABASE, this.config.database || DEFAULT_DB); + // meta interceptor, add the injector const metaInterceptor = getMetaInterceptor( this.metadataListener.bind(this) @@ -50,76 +82,61 @@ export class GRPCClient extends User { : this.config.retryDelay, clientId: this.clientId, }); + // interceptors const interceptors = [metaInterceptor, retryInterceptor]; // add interceptors this.channelOptions.interceptors = interceptors; - // setup auth if necessary - const auth = getAuthString(this.config); - if (auth.length > 0) { - this.metadata.set(METADATA.AUTH, auth); - } - - // setup database - this.metadata.set(METADATA.DATABASE, this.config.database || DEFAULT_DB); - - // create credentials - let creds: ChannelCredentials; - - // assign credentials according to the tls mode - switch (this.tlsMode) { - case TLS_MODE.ONE_WAY: - // create ssl with empty parameters - creds = credentials.createSsl(); - break; - case TLS_MODE.TWO_WAY: - const { rootCertPath, privateKeyPath, certChainPath, verifyOptions } = - this.config.tls!; - - // init - let rootCertBuff: Buffer | null = null; - let privateKeyBuff: Buffer | null = null; - let certChainBuff: Buffer | null = null; - - // read root cert file - if (rootCertPath) { - rootCertBuff = readFileSync(rootCertPath); - } - - // read private key file - if (privateKeyPath) { - privateKeyBuff = readFileSync(privateKeyPath); - } + // create grpc pool + this.channelPool = this.createChannelPool(MilvusService); + } - // read cert chain file - if (certChainPath) { - certChainBuff = readFileSync(certChainPath); - } + // create a grpc service client(connect) + connect(sdkVersion: string) { + // connect to get identifier + this.connectPromise = this._getServerInfo(sdkVersion); + } - // create credentials - creds = credentials.createSsl( - rootCertBuff, - privateKeyBuff, - certChainBuff, - verifyOptions - ); - break; - default: - creds = credentials.createInsecure(); - break; - } + // return client acquired from pool + get client() { + return this.channelPool.acquire(); + } - // create grpc client - this.client = new MilvusService( - formatAddress(this.config.address), // format the address - creds, - this.channelOptions + /** + * Creates a pool of gRPC service clients. + * + * @param {ServiceClientConstructor} ServiceClientConstructor - The constructor for the gRPC service client. + * + * @returns {Pool} - A pool of gRPC service clients. + */ + private createChannelPool( + ServiceClientConstructor: ServiceClientConstructor + ) { + return createPool( + { + create: async () => { + // Create a new gRPC service client + return new ServiceClientConstructor( + formatAddress(this.config.address), // format the address + this.creds, + this.channelOptions + ); + }, + destroy: async (client: Client) => { + // Close the gRPC service client + return new Promise((resolve, reject) => { + client.close(); + resolve(client.getChannel().getConnectivityState(true)); + }); + }, + }, + this.config.pool ?? { + min: DEFAULT_POOL_MIN, + max: DEFAULT_POOL_MAX, + } ); - - // connect to get identifier - this.connectPromise = this._getServerInfo(sdkVersion); } /** @@ -177,43 +194,40 @@ export class GRPCClient extends User { // update connect status this.connectStatus = CONNECT_STATUS.CONNECTING; - return promisify(this.client, 'Connect', userInfo, this.timeout).then(f => { - // add new identifier interceptor - if (f && f.identifier) { - // update identifier - this.metadata.set(METADATA.CLIENT_ID, f.identifier); + return promisify(this.channelPool, 'Connect', userInfo, this.timeout).then( + f => { + // add new identifier interceptor + if (f && f.identifier) { + // update identifier + this.metadata.set(METADATA.CLIENT_ID, f.identifier); - // setup identifier - this.serverInfo = f.server_info; + // setup identifier + this.serverInfo = f.server_info; + } + // update connect status + this.connectStatus = + f && f.identifier + ? CONNECT_STATUS.CONNECTED + : CONNECT_STATUS.UNIMPLEMENTED; } - // update connect status - this.connectStatus = - f && f.identifier - ? CONNECT_STATUS.CONNECTED - : CONNECT_STATUS.UNIMPLEMENTED; - }); + ); } /** - * Closes the gRPC client connection and returns the connectivity state of the channel. - * This method should be called before terminating the application or when the client is no longer needed. - * This method returns a number that represents the connectivity state of the channel: - * - 0: CONNECTING - * - 1: READY - * - 2: IDLE - * - 3: TRANSIENT FAILURE - * - 4: FATAL FAILURE - * - 5: SHUTDOWN + * Closes the connection to the Milvus server. + * This method drains and clears the connection pool, and updates the connection status to SHUTDOWN. + * @returns {Promise} The updated connection status. */ - closeConnection() { - // Close the gRPC client connection - if (this.client) { - this.client.close(); - } - // grpc client closed -> 4, connected -> 0 - if (this.client) { - return this.client.getChannel().getConnectivityState(true); + async closeConnection() { + // Close all connections in the pool + if (this.channelPool) { + await this.channelPool.drain(); + await this.channelPool.clear(); + + // update status + this.connectStatus = CONNECT_STATUS.SHUTDOWN; } + return this.connectStatus; } /** @@ -221,7 +235,7 @@ export class GRPCClient extends User { * This method returns a Promise that resolves with a `GetVersionResponse` object. */ async getVersion(): Promise { - return await promisify(this.client, 'GetVersion', {}, this.timeout); + return await promisify(this.channelPool, 'GetVersion', {}, this.timeout); } /** @@ -229,6 +243,6 @@ export class GRPCClient extends User { * This method returns a Promise that resolves with a `CheckHealthResponse` object. */ async checkHealth(): Promise { - return await promisify(this.client, 'CheckHealth', {}, this.timeout); + return await promisify(this.channelPool, 'CheckHealth', {}, this.timeout); } } diff --git a/milvus/grpc/MilvusIndex.ts b/milvus/grpc/MilvusIndex.ts index 607880f2..c31ac9b3 100644 --- a/milvus/grpc/MilvusIndex.ts +++ b/milvus/grpc/MilvusIndex.ts @@ -88,7 +88,7 @@ export class Index extends Data { // Call the 'CreateIndex' gRPC method and return the result const promise = await promisify( - this.client, + this.channelPool, 'CreateIndex', createIndexParams, data.timeout || this.timeout @@ -124,7 +124,7 @@ export class Index extends Data { async describeIndex(data: DescribeIndexReq): Promise { checkCollectionName(data); const promise = await promisify( - this.client, + this.channelPool, 'DescribeIndex', data, data.timeout || this.timeout @@ -160,7 +160,7 @@ export class Index extends Data { async getIndexState(data: GetIndexStateReq): Promise { checkCollectionName(data); const promise = await promisify( - this.client, + this.channelPool, 'GetIndexState', data, data.timeout || this.timeout @@ -201,7 +201,7 @@ export class Index extends Data { ): Promise { checkCollectionName(data); const promise = await promisify( - this.client, + this.channelPool, 'GetIndexBuildProgress', data, data.timeout || this.timeout @@ -239,7 +239,7 @@ export class Index extends Data { async dropIndex(data: DropIndexReq): Promise { checkCollectionName(data); const promise = await promisify( - this.client, + this.channelPool, 'DropIndex', data, data.timeout || this.timeout diff --git a/milvus/grpc/Partition.ts b/milvus/grpc/Partition.ts index 20e78d5d..88d8139f 100644 --- a/milvus/grpc/Partition.ts +++ b/milvus/grpc/Partition.ts @@ -48,7 +48,7 @@ export class Partition extends Index { async createPartition(data: CreatePartitionReq): Promise { checkCollectionAndPartitionName(data); const promise = await promisify( - this.client, + this.channelPool, 'CreatePartition', data, data.timeout || this.timeout @@ -85,7 +85,7 @@ export class Partition extends Index { async hasPartition(data: HasPartitionReq): Promise { checkCollectionAndPartitionName(data); const promise = await promisify( - this.client, + this.channelPool, 'HasPartition', data, data.timeout || this.timeout @@ -124,7 +124,7 @@ export class Partition extends Index { ): Promise { checkCollectionName(data); const promise = await promisify( - this.client, + this.channelPool, 'ShowPartitions', data, data.timeout || this.timeout @@ -165,7 +165,7 @@ export class Partition extends Index { ): Promise { checkCollectionAndPartitionName(data); const promise = await promisify( - this.client, + this.channelPool, 'GetPartitionStatistics', data, data.timeout || this.timeout @@ -207,7 +207,7 @@ export class Partition extends Index { throw new Error(ERROR_REASONS.PARTITION_NAMES_IS_REQUIRED); } const promise = await promisify( - this.client, + this.channelPool, 'LoadPartitions', data, data.timeout || this.timeout @@ -247,7 +247,7 @@ export class Partition extends Index { throw new Error(ERROR_REASONS.PARTITION_NAMES_IS_REQUIRED); } const promise = await promisify( - this.client, + this.channelPool, 'ReleasePartitions', data, data.timeout || this.timeout @@ -290,7 +290,7 @@ export class Partition extends Index { async dropPartition(data: DropPartitionReq): Promise { checkCollectionAndPartitionName(data); const promise = await promisify( - this.client, + this.channelPool, 'DropPartition', data, data.timeout || this.timeout diff --git a/milvus/grpc/Resource.ts b/milvus/grpc/Resource.ts index 997f3b7a..36ce0733 100644 --- a/milvus/grpc/Resource.ts +++ b/milvus/grpc/Resource.ts @@ -39,7 +39,7 @@ export class Resource extends Partition { */ async createResourceGroup(data: CreateResourceGroupReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'CreateResourceGroup', data, data.timeout || this.timeout @@ -68,7 +68,7 @@ export class Resource extends Partition { data?: GrpcTimeOut ): Promise { const promise = await promisify( - this.client, + this.channelPool, 'ListResourceGroups', {}, data?.timeout || this.timeout @@ -108,7 +108,7 @@ export class Resource extends Partition { data: DescribeResourceGroupsReq ): Promise { const promise = await promisify( - this.client, + this.channelPool, 'DescribeResourceGroup', data, data.timeout || this.timeout @@ -140,7 +140,7 @@ export class Resource extends Partition { */ async dropResourceGroup(data: DropResourceGroupsReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'DropResourceGroup', data, data.timeout || this.timeout @@ -180,7 +180,7 @@ export class Resource extends Partition { /* istanbul ignore next */ async transferReplica(data: TransferReplicaReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'TransferReplica', data, data.timeout || this.timeout @@ -218,7 +218,7 @@ export class Resource extends Partition { /* istanbul ignore next */ async transferNode(data: TransferNodeReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'TransferNode', data, data.timeout || this.timeout diff --git a/milvus/grpc/User.ts b/milvus/grpc/User.ts index 062056ae..e836fa56 100644 --- a/milvus/grpc/User.ts +++ b/milvus/grpc/User.ts @@ -61,7 +61,7 @@ export class User extends Resource { } const encryptedPassword = stringToBase64(data.password); const promise = await promisify( - this.client, + this.channelPool, 'CreateCredential', { username: data.username, @@ -110,7 +110,7 @@ export class User extends Resource { const encryptedNewPwd = stringToBase64(data.newPassword); const promise = await promisify( - this.client, + this.channelPool, 'UpdateCredential', { username: data.username, @@ -150,7 +150,7 @@ export class User extends Resource { throw new Error(ERROR_REASONS.USERNAME_IS_REQUIRED); } const promise = await promisify( - this.client, + this.channelPool, 'DeleteCredential', { username: data.username, @@ -182,7 +182,7 @@ export class User extends Resource { */ async listUsers(data?: ListUsersReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'ListCredUsers', {}, data?.timeout || this.timeout @@ -213,7 +213,7 @@ export class User extends Resource { */ async createRole(data: CreateRoleReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'CreateRole', { entity: { name: data.roleName }, @@ -246,7 +246,7 @@ export class User extends Resource { */ async dropRole(data: DropRoleReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'DropRole', { role_name: data.roleName, @@ -280,7 +280,7 @@ export class User extends Resource { */ async addUserToRole(data: AddUserToRoleReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'OperateUserRole', { username: data.username, @@ -316,7 +316,7 @@ export class User extends Resource { */ async removeUserFromRole(data: RemoveUserFromRoleReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'OperateUserRole', { username: data.username, @@ -352,7 +352,7 @@ export class User extends Resource { */ async selectRole(data: SelectRoleReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'SelectRole', { role: { name: data.roleName }, @@ -386,7 +386,7 @@ export class User extends Resource { */ async listRoles(data?: listRoleReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'SelectRole', { include_user_info: data?.includeUserInfo || true, @@ -420,7 +420,7 @@ export class User extends Resource { */ async selectUser(data: SelectUserReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'SelectUser', { user: { name: data.username }, @@ -463,7 +463,7 @@ export class User extends Resource { */ async grantRolePrivilege(data: OperateRolePrivilegeReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'OperatePrivilege', { entity: { @@ -513,7 +513,7 @@ export class User extends Resource { */ async revokeRolePrivilege(data: OperateRolePrivilegeReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'OperatePrivilege', { entity: { @@ -621,7 +621,7 @@ export class User extends Resource { */ async selectGrant(data: SelectGrantReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'SelectGrant', { entity: { @@ -663,7 +663,7 @@ export class User extends Resource { */ async listGrants(data: ListGrantsReq): Promise { const promise = await promisify( - this.client, + this.channelPool, 'SelectGrant', { entity: { diff --git a/milvus/types/Client.ts b/milvus/types/Client.ts index 7a5a45a2..033e08a0 100644 --- a/milvus/types/Client.ts +++ b/milvus/types/Client.ts @@ -1,4 +1,5 @@ import { ChannelOptions } from '@grpc/grpc-js'; +import { Options } from 'generic-pool'; /** * Configuration options for the Milvus client. @@ -46,6 +47,12 @@ export interface ClientConfig { // server name serverName?: string; }; + + // generic-pool options: refer to https://github.com/coopernurse/node-pool + pool?: Options; + + // internal property for debug & test + __SKIP_CONNECT__?: boolean; } export interface ServerInfo { diff --git a/milvus/utils/Function.ts b/milvus/utils/Function.ts index e9bf95ea..cc797b63 100644 --- a/milvus/utils/Function.ts +++ b/milvus/utils/Function.ts @@ -1,4 +1,5 @@ import { KeyValuePair, DataType, ERROR_REASONS } from '../'; +import { Pool } from 'generic-pool'; /** * Promisify a function call with optional timeout @@ -8,8 +9,8 @@ import { KeyValuePair, DataType, ERROR_REASONS } from '../'; * @param timeout - Optional timeout in milliseconds * @returns A Promise that resolves with the result of the target function call */ -export function promisify( - obj: any, +export async function promisify( + pool: Pool, target: string, params: any, timeout: number @@ -17,11 +18,14 @@ export function promisify( // Calculate the deadline for the function call const t = timeout === 0 ? 1000 * 60 * 60 * 24 : timeout; + // get client + const client = await pool.acquire(); + // Create a new Promise that wraps the target function call - const res = new Promise((resolve, reject) => { + return new Promise((resolve, reject) => { try { // Call the target function with the provided parameters and deadline - obj[target]( + client[target]( params, { deadline: new Date(Date.now() + t) }, (err: any, result: any) => { @@ -34,16 +38,13 @@ export function promisify( } ); } catch (e: any) { - // If there was an exception, throw a new Error - throw new Error(e); + reject(e); + } finally { + if (client) { + pool.release(client); + } } - }).catch(err => { - // Return a rejected Promise with the error - return Promise.reject(err); }); - - // Return the Promise - return res; } export const findKeyValue = (obj: KeyValuePair[], key: string) => diff --git a/package.json b/package.json index b295574f..bb650b92 100644 --- a/package.json +++ b/package.json @@ -23,6 +23,7 @@ "@grpc/grpc-js": "1.8.17", "@grpc/proto-loader": "0.7.7", "dayjs": "^1.11.7", + "generic-pool": "^3.9.0", "lru-cache": "^9.1.2", "protobufjs": "7.2.4", "winston": "^3.9.0" diff --git a/test/grpc/Collection.spec.ts b/test/grpc/Collection.spec.ts index 97acd30c..538ff54d 100644 --- a/test/grpc/Collection.spec.ts +++ b/test/grpc/Collection.spec.ts @@ -513,7 +513,9 @@ describe(`Collection API`, () => { collection_name: LOAD_COLLECTION_NAME, }); - expect(Number(formatKeyValueData(describe.properties, [key])[key])).toEqual(value); + expect(Number(formatKeyValueData(describe.properties, [key])[key])).toEqual( + value + ); }); it(`Create alias success`, async () => { diff --git a/test/grpc/MilvusClient.spec.ts b/test/grpc/MilvusClient.spec.ts index 030f44b5..c7ae8684 100644 --- a/test/grpc/MilvusClient.spec.ts +++ b/test/grpc/MilvusClient.spec.ts @@ -1,4 +1,10 @@ -import { MilvusClient, ERROR_REASONS, CONNECT_STATUS } from '../../milvus'; +import path from 'path'; +import { + MilvusClient, + ERROR_REASONS, + CONNECT_STATUS, + TLS_MODE, +} from '../../milvus'; import sdkInfo from '../../sdk.json'; import { IP } from '../tools'; @@ -6,60 +12,85 @@ const milvusClient = new MilvusClient({ address: IP, }); +// path +const milvusProtoPath = path.resolve( + __dirname, + '../../proto/proto/milvus.proto' +); +const schemaProtoPath = path.resolve( + __dirname, + '../../proto/proto/schema.proto' +); + describe(`Milvus client`, () => { afterEach(() => { jest.clearAllMocks(); }); - // it(`should create a grpc client with cert file successfully`, async () => { - // const milvusClient = new MilvusClient({ - // address: IP, - // tls: { - // rootCertPath: `test/cert/ca.pem`, - // privateKeyPath: `test/cert/client.key`, - // certChainPath: `test/cert/client.pem`, - // serverName: IP, - // }, - // id: '1', - // }); - - // expect(milvusClient.client).toBeDefined(); - // expect(milvusClient.tlsMode).toEqual(2); - // expect(milvusClient.clientId).toEqual('1'); - // }); - - it(`should create a grpc client without SSL credentials when ssl is false`, () => { - const milvusClient = new MilvusClient({ + it(`should create a grpc client with cert file successfully`, async () => { + const m1 = new MilvusClient({ + address: IP, + tls: { + rootCertPath: `test/cert/ca.pem`, + privateKeyPath: `test/cert/client.key`, + certChainPath: `test/cert/client.pem`, + serverName: IP, + }, + id: '1', + __SKIP_CONNECT__: true, + }); + + expect(await m1.client).toBeDefined(); + expect(m1.tlsMode).toEqual(TLS_MODE.TWO_WAY); + expect(m1.clientId).toEqual('1'); + }); + + it(`should create a grpc client without SSL credentials when ssl is false`, async () => { + const m2 = new MilvusClient({ address: IP, - ssl: false, + ssl: true, username: 'username', password: 'password', id: '1', + __SKIP_CONNECT__: true, }); - expect(milvusClient.clientId).toEqual('1'); - expect(milvusClient.client).toBeDefined(); + expect(m2.clientId).toEqual('1'); + expect(await m2.client).toBeDefined(); + expect(m2.tlsMode).toEqual(TLS_MODE.ONE_WAY); }); it(`should create a grpc client without authentication when username and password are not provided`, () => { - const milvusClient = new MilvusClient(IP, false); - - expect(milvusClient.client).toBeDefined(); + const m3 = new MilvusClient(IP, false); + expect(m3.client).toBeDefined(); }); it(`should have connect promise and connectStatus`, async () => { - const milvusClient = new MilvusClient(IP, false); - expect(milvusClient.connectPromise).toBeDefined(); + const m4 = new MilvusClient(IP, false); + expect(m4.connectPromise).toBeDefined(); - await milvusClient.connectPromise; - expect(milvusClient.connectStatus).not.toEqual( - CONNECT_STATUS.NOT_CONNECTED - ); + await m4.connectPromise; + expect(m4.connectStatus).not.toEqual(CONNECT_STATUS.NOT_CONNECTED); }); - it(`should create a grpc client with authentication when username and password are provided`, () => { - const milvusClient = new MilvusClient(IP, false, `username`, `password`); - expect(milvusClient.client).toBeDefined(); + it(`should create a grpc client with authentication when username and password are provided`, async () => { + const m5 = new MilvusClient(IP, false, `username`, `password`); + expect(await m5.client).toBeDefined(); + }); + + it(`should setup protofile path successfully`, async () => { + const m6 = new MilvusClient({ + address: IP, + protoFilePath: { + milvus: milvusProtoPath, + schema: schemaProtoPath, + }, + __SKIP_CONNECT__: true, + }); + + expect(await m6.client).toBeDefined(); + expect(m6.protoFilePath.milvus).toEqual(milvusProtoPath); + expect(m6.protoFilePath.schema).toEqual(schemaProtoPath); }); it(`Should throw MILVUS_ADDRESS_IS_REQUIRED`, async () => { @@ -88,7 +119,10 @@ describe(`Milvus client`, () => { }); it(`Expect close connection success`, async () => { - const res = milvusClient.closeConnection(); - expect(res).toEqual(4); + expect(milvusClient.channelPool.size).toBeGreaterThan(0); + + const res = await milvusClient.closeConnection(); + expect(milvusClient.channelPool.size).toBe(0); + expect(res).toBe(CONNECT_STATUS.SHUTDOWN); }); }); diff --git a/test/utils/Function.spec.ts b/test/utils/Function.spec.ts index 3fa45c32..7cb61466 100644 --- a/test/utils/Function.spec.ts +++ b/test/utils/Function.spec.ts @@ -1,56 +1,49 @@ -import { promisify } from '../../milvus'; +import { promisify } from '../../milvus/utils'; -describe(`utils/function`, () => { - it('should resolve with the result of the target function call', async () => { - const obj = { - target: (params: any, options: any, callback: any) => { - callback(null, 'result'); - }, +describe('promisify', () => { + let pool: any; + let client: any; + + beforeEach(() => { + client = { + testFunction: jest.fn((params, options, callback) => + callback(null, 'success') + ), + }; + pool = { + acquire: jest.fn().mockResolvedValue(client), + release: jest.fn(), }; - const target = 'target'; - const params = {}; - const timeout = 1000; - const result = await promisify(obj, target, params, timeout); - expect(result).toEqual('result'); }); - it('should reject with the error if there was an error', async () => { - const obj = { - target: (params: any, options: any, callback: any) => { - callback(new Error('error')); - }, - }; - const target = 'target'; - const params = {}; - const timeout = 1000; - await expect(promisify(obj, target, params, timeout)).rejects.toThrow( + it('should resolve with the result of the function call', async () => { + const result = await promisify(pool, 'testFunction', {}, 1000); + expect(result).toBe('success'); + expect(client.testFunction).toHaveBeenCalled(); + expect(pool.acquire).toHaveBeenCalled(); + expect(pool.release).toHaveBeenCalled(); + }); + + it('should reject if the function call results in an error', async () => { + client.testFunction = jest.fn((params, options, callback) => + callback('error') + ); + await expect(promisify(pool, 'testFunction', {}, 1000)).rejects.toBe( 'error' ); + expect(client.testFunction).toHaveBeenCalled(); + expect(pool.acquire).toHaveBeenCalled(); + expect(pool.release).toHaveBeenCalled(); }); - it('should reject with the error if there was an exception', async () => { - const obj = { - target: () => { - throw new Error('exception'); - }, - }; - const target = 'target'; - const params = {}; - const timeout = 1000; - await expect(promisify(obj, target, params, timeout)).rejects.toThrow( + it('should reject if the function call throws an exception', async () => { + client.testFunction = jest.fn(() => { + throw new Error('exception'); + }); + await expect(promisify(pool, 'testFunction', {}, 1000)).rejects.toThrow( 'exception' ); - }); - - it('should use the default timeout if no timeout is provided', async () => { - const obj = { - target: (params: any, options: any, callback: any) => { - callback(null, 'result'); - }, - }; - const target = 'target'; - const params = {}; - const result = await promisify(obj, target, params, 0); - expect(result).toEqual('result'); + expect(pool.acquire).toHaveBeenCalled(); + expect(pool.release).toHaveBeenCalled(); }); }); diff --git a/yarn.lock b/yarn.lock index c8d1b0d7..69db35a1 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1472,6 +1472,11 @@ function-bind@^1.1.1: resolved "https://registry.yarnpkg.com/function-bind/-/function-bind-1.1.1.tgz#a56899d3ea3c9bab874bb9773b7c5ede92f4895d" integrity sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A== +generic-pool@^3.9.0: + version "3.9.0" + resolved "https://registry.yarnpkg.com/generic-pool/-/generic-pool-3.9.0.tgz#36f4a678e963f4fdb8707eab050823abc4e8f5e4" + integrity sha512-hymDOu5B53XvN4QT9dBmZxPX4CWhBPPLguTZ9MMFeFa/Kg0xWVfylOVNlJji/E7yTZWFd/q9GO5TxDLq156D7g== + gensync@^1.0.0-beta.2: version "1.0.0-beta.2" resolved "https://registry.yarnpkg.com/gensync/-/gensync-1.0.0-beta.2.tgz#32a6ee76c3d7f52d46b2b1ae5d93fea8580a25e0"