From a44a4b164310b2db959aac7b5577163f525d18d6 Mon Sep 17 00:00:00 2001 From: ryjiang Date: Thu, 24 Oct 2024 10:55:19 +0800 Subject: [PATCH] [2.5] support functions (#364) * WIP: functions Signed-off-by: ryjiang * WIP: functions Signed-off-by: ryjiang * WIP Signed-off-by: ryjiang * WIP Signed-off-by: ryjiang * suppport functions part2 Signed-off-by: ryjiang * functions part3 Signed-off-by: ryjiang * finish functions Signed-off-by: ryjiang * update test Signed-off-by: ryjiang --------- Signed-off-by: ryjiang --- milvus/const/milvus.ts | 7 ++ milvus/grpc/BaseClient.ts | 14 +-- milvus/grpc/Collection.ts | 19 ++- milvus/grpc/Data.ts | 2 + milvus/types/Collection.ts | 33 +++-- milvus/types/Data.ts | 17 ++- milvus/utils/Bytes.ts | 61 +++++---- milvus/utils/Format.ts | 132 ++++++++++++++------ milvus/utils/Grpc.ts | 4 +- package.json | 2 +- proto | 2 +- test/grpc/Functions.spec.ts | 242 ++++++++++++++++++++++++++++++++++++ test/tools/collection.ts | 16 ++- test/tools/data.ts | 5 + test/utils/Format.spec.ts | 146 +++++++++++++--------- 15 files changed, 545 insertions(+), 157 deletions(-) create mode 100644 test/grpc/Functions.spec.ts diff --git a/milvus/const/milvus.ts b/milvus/const/milvus.ts index a76bdabf..977101ef 100644 --- a/milvus/const/milvus.ts +++ b/milvus/const/milvus.ts @@ -117,6 +117,8 @@ export enum MetricType { SUBSTRUCTURE = 'SUBSTRUCTURE', // SUPERSTRUCTURE superstructure SUPERSTRUCTURE = 'SUPERSTRUCTURE', + // BM 25 + BM25 = 'BM25', } // Index types @@ -279,6 +281,11 @@ export enum DataType { SparseFloatVector = 104, } +export enum FunctionType { + Unknown = 0, + BM25 = 1, +} + export const VectorDataTypes = [ DataType.BinaryVector, DataType.FloatVector, diff --git a/milvus/grpc/BaseClient.ts b/milvus/grpc/BaseClient.ts index 383f83a7..a3294dc8 100644 --- a/milvus/grpc/BaseClient.ts +++ b/milvus/grpc/BaseClient.ts @@ -67,15 +67,13 @@ export class BaseClient { protected schemaProto: Root; // The Milvus protobuf. protected milvusProto: Root; - // The milvus collection schema Type - protected collectionSchemaType: Type; - // The milvus field schema Type - protected fieldSchemaType: Type; + // milvus proto protected readonly protoInternalPath = { serviceName: 'milvus.proto.milvus.MilvusService', collectionSchema: 'milvus.proto.schema.CollectionSchema', fieldSchema: 'milvus.proto.schema.FieldSchema', + functionSchema: 'milvus.proto.schema.FunctionSchema', }; /** @@ -136,14 +134,6 @@ export class BaseClient { this.schemaProto = protobuf.loadSync(this.protoFilePath.schema); this.milvusProto = protobuf.loadSync(this.protoFilePath.milvus); - // Get the CollectionSchemaType and FieldSchemaType from the schemaProto object. - this.collectionSchemaType = this.schemaProto.lookupType( - this.protoInternalPath.collectionSchema - ); - this.fieldSchemaType = this.schemaProto.lookupType( - this.protoInternalPath.fieldSchema - ); - // options this.channelOptions = { // Milvus default max_receive_message_length is 100MB, but Milvus support change max_receive_message_length . diff --git a/milvus/grpc/Collection.ts b/milvus/grpc/Collection.ts index a1bf553d..b76d1700 100644 --- a/milvus/grpc/Collection.ts +++ b/milvus/grpc/Collection.ts @@ -145,15 +145,28 @@ export class Collection extends Database { validatePartitionNumbers(num_partitions); } + // Get the CollectionSchemaType and FieldSchemaType from the schemaProto object. + const schemaTypes = { + collectionSchemaType: this.schemaProto.lookupType( + this.protoInternalPath.collectionSchema + ), + fieldSchemaType: this.schemaProto.lookupType( + this.protoInternalPath.fieldSchema + ), + functionSchemaType: this.schemaProto.lookupType( + this.protoInternalPath.functionSchema + ), + }; + // Create the payload object with the collection_name, description, and fields. // it should follow CollectionSchema in schema.proto - const payload = formatCollectionSchema(data, this.fieldSchemaType); + const payload = formatCollectionSchema(data, schemaTypes); // Create the collectionParams object from the payload. - const collectionSchema = this.collectionSchemaType.create(payload); + const collectionSchema = schemaTypes.collectionSchemaType.create(payload); // Encode the collectionParams object to bytes. - const schemaBytes = this.collectionSchemaType + const schemaBytes = schemaTypes.collectionSchemaType .encode(collectionSchema) .finish(); diff --git a/milvus/grpc/Data.ts b/milvus/grpc/Data.ts index 298a5b62..737844f4 100644 --- a/milvus/grpc/Data.ts +++ b/milvus/grpc/Data.ts @@ -141,9 +141,11 @@ export class Data extends Collection { // Tip: The field data sequence needs to be set same as `collectionInfo.schema.fields`. // If primarykey is set `autoid = true`, you cannot insert the data. + // and if function field is set, you need to ignore the field value in the data. const fieldMap = new Map( collectionInfo.schema.fields .filter(v => !v.is_primary_key || !v.autoID) + .filter(v => !v.is_function_output) .map(v => [ v.name, { diff --git a/milvus/types/Collection.ts b/milvus/types/Collection.ts index eeba9927..3e470443 100644 --- a/milvus/types/Collection.ts +++ b/milvus/types/Collection.ts @@ -13,6 +13,7 @@ import { LoadState, DataTypeMap, ShowCollectionsType, + FunctionType, } from '../'; // returned from milvus @@ -21,7 +22,7 @@ export interface FieldSchema { index_params: KeyValuePair[]; fieldID: string | number; name: string; - is_primary_key?: boolean; + is_primary_key: boolean; description: string; data_type: keyof typeof DataType; autoID: boolean; @@ -29,9 +30,11 @@ export interface FieldSchema { element_type?: keyof typeof DataType; default_value?: number | string; dataType: DataType; - is_partition_key?: boolean; - is_dynamic?: boolean; - is_clustering_key?: boolean; + is_partition_key: boolean; + is_dynamic: boolean; + is_clustering_key: boolean; + is_function_output: boolean; + nullable: boolean; } export interface CollectionData { @@ -56,7 +59,7 @@ export interface ReplicaInfo { node_ids: string[]; } -export type TypeParam = string | number; +export type TypeParam = string | number | Record; export type TypeParamKey = 'dim' | 'max_length' | 'max_capacity'; // create collection @@ -67,6 +70,7 @@ export interface FieldType { element_type?: DataType | keyof typeof DataTypeMap; is_primary_key?: boolean; is_partition_key?: boolean; + is_function_output?: boolean; type_params?: { [key: string]: TypeParam; }; @@ -75,6 +79,9 @@ export interface FieldType { max_capacity?: TypeParam; max_length?: TypeParam; default_value?: number | string; + enable_match?: boolean; + tokenizer_params?: Record; + enable_tokenizer?: boolean; } export interface ShowCollectionsReq extends GrpcTimeOut { @@ -85,6 +92,15 @@ export interface ShowCollectionsReq extends GrpcTimeOut { export type Properties = Record; +export type Function = { + name: string; + description?: string; + type: FunctionType; + input_field_names: string[]; + output_field_names?: string[]; + params: Record; +}; + export interface BaseCreateCollectionReq extends GrpcTimeOut { // collection name collection_name: string; // required, collection name @@ -100,8 +116,9 @@ export interface BaseCreateCollectionReq extends GrpcTimeOut { partition_key_field?: string; // optional, partition key field enable_dynamic_field?: boolean; // optional, enable dynamic field, default is false enableDynamicField?: boolean; // optional, alias of enable_dynamic_field - properties?: Properties; - db_name?: string; + properties?: Properties; // optional, collection properties + db_name?: string; // optional, db name + functions?: Function[]; // optionals, doc-in/doc-out functions } export interface CreateCollectionWithFieldsReq extends BaseCreateCollectionReq { @@ -185,6 +202,7 @@ export interface CollectionSchema { enable_dynamic_field: boolean; autoID: boolean; fields: FieldSchema[]; + functions: Function[]; } export interface DescribeCollectionResponse extends TimeStamp { @@ -203,6 +221,7 @@ export interface DescribeCollectionResponse extends TimeStamp { shards_num: number; num_partitions?: string; // int64 db_name: string; + functions: Function[]; } export interface GetCompactionPlansResponse extends resStatusResponse { diff --git a/milvus/types/Data.ts b/milvus/types/Data.ts index d2434515..0be956c5 100644 --- a/milvus/types/Data.ts +++ b/milvus/types/Data.ts @@ -282,7 +282,7 @@ export interface SearchParam { group_by_field?: string; // group by field } -// old search api parameter type +// old search api parameter type, deprecated export interface SearchReq extends collectionNameReq { anns_field?: string; // your vector field name partition_names?: string[]; // partition names @@ -307,13 +307,18 @@ export interface SearchIteratorReq limit: number; } +export type SearchTextType = string | string[]; +export type SearchVectorType = VectorTypes | VectorTypes[]; +export type SearchDataType = SearchVectorType | SearchTextType; +export type SearchMultipleDataType = VectorTypes[] | SearchTextType[]; + // simplified search api parameter type export interface SearchSimpleReq extends collectionNameReq { partition_names?: string[]; // partition names - anns_field?: string; // your vector field name - data?: VectorTypes[] | VectorTypes; // vector to search - vector?: VectorTypes; // alias for data - vectors?: VectorTypes[]; // alias for data + anns_field?: string; // your vector field name,rquired if you are searching on multiple vector fields collection + data?: SearchDataType; // vector or text to search + vector?: VectorTypes; // alias for data, deprecated + vectors?: VectorTypes[]; // alias for data, deprecated output_fields?: string[]; limit?: number; // how many results you want topk?: number; // limit alias @@ -333,7 +338,7 @@ export type HybridSearchSingleReq = Pick< SearchParam, 'anns_field' | 'ignore_growing' | 'group_by_field' > & { - data: VectorTypes[] | VectorTypes; // vector to search + data: SearchDataType; // vector to search expr?: string; // filter expression params?: keyValueObj; // extra search parameters transformers?: OutputTransformers; // provide custom data transformer for specific data type like bf16 or f16 vectors diff --git a/milvus/utils/Bytes.ts b/milvus/utils/Bytes.ts index 66f354b6..3e63f349 100644 --- a/milvus/utils/Bytes.ts +++ b/milvus/utils/Bytes.ts @@ -5,12 +5,13 @@ import { BinaryVector, SparseFloatVector, DataType, - VectorTypes, + SearchMultipleDataType, Float16Vector, SparseVectorCSR, SparseVectorCOO, BFloat16Vector, SparseVectorArray, + FieldSchema, } from '..'; /** @@ -250,41 +251,49 @@ export const bytesToSparseRow = (bufferData: Buffer): SparseFloatVector => { * This function builds a placeholder group in bytes format for Milvus. * * @param {Root} milvusProto - The root object of the Milvus protocol. - * @param {VectorTypes[]} vectors - An array of search vectors. + * @param {SearchMultipleDataType[]} data - An array of search vectors. * @param {DataType} vectorDataType - The data type of the vectors. * * @returns {Uint8Array} The placeholder group in bytes format. */ export const buildPlaceholderGroupBytes = ( milvusProto: Root, - vectors: VectorTypes[], - vectorDataType: DataType + data: SearchMultipleDataType, + field: FieldSchema ) => { + const { dataType, is_function_output } = field; // create placeholder_group value let bytes; - // parse vectors to bytes - switch (vectorDataType) { - case DataType.FloatVector: - bytes = vectors.map(v => f32ArrayToF32Bytes(v as FloatVector)); - break; - case DataType.BinaryVector: - bytes = vectors.map(v => f32ArrayToBinaryBytes(v as BinaryVector)); - break; - case DataType.BFloat16Vector: - bytes = vectors.map(v => - Array.isArray(v) ? f32ArrayToBf16Bytes(v as BFloat16Vector) : v - ); - break; - case DataType.Float16Vector: - bytes = vectors.map(v => - Array.isArray(v) ? f32ArrayToF16Bytes(v as Float16Vector) : v - ); - break; - case DataType.SparseFloatVector: - bytes = vectors.map(v => sparseToBytes(v as SparseFloatVector)); - break; + if (is_function_output) { + // parse text to bytes + bytes = data.map(d => new TextEncoder().encode(String(d))); + } else { + // parse vectors to bytes + switch (dataType) { + case DataType.FloatVector: + bytes = data.map(v => f32ArrayToF32Bytes(v as FloatVector)); + break; + case DataType.BinaryVector: + bytes = data.map(v => f32ArrayToBinaryBytes(v as BinaryVector)); + break; + case DataType.BFloat16Vector: + bytes = data.map(v => + Array.isArray(v) ? f32ArrayToBf16Bytes(v as BFloat16Vector) : v + ); + break; + case DataType.Float16Vector: + bytes = data.map(v => + Array.isArray(v) ? f32ArrayToF16Bytes(v as Float16Vector) : v + ); + break; + case DataType.SparseFloatVector: + bytes = data.map(v => sparseToBytes(v as SparseFloatVector)); + + break; + } } + // create placeholder_group const PlaceholderGroup = milvusProto.lookupType( 'milvus.proto.common.PlaceholderGroup' @@ -295,7 +304,7 @@ export const buildPlaceholderGroupBytes = ( placeholders: [ { tag: '$0', - type: vectorDataType, + type: is_function_output ? DataType.VarChar : dataType, values: bytes, }, ], diff --git a/milvus/utils/Format.ts b/milvus/utils/Format.ts index b5e8aa47..967a121b 100644 --- a/milvus/utils/Format.ts +++ b/milvus/utils/Format.ts @@ -41,6 +41,10 @@ import { f32ArrayToF16Bytes, bf16BytesToF32Array, f16BytesToF32Array, + TypeParam, + SearchDataType, + FieldSchema, + SearchMultipleDataType, } from '../'; /** @@ -64,12 +68,18 @@ export const formatKeyValueData = (data: KeyValuePair[], keys: string[]) => { * @param data Object * @returns {KeyValuePair[]} */ -export const parseToKeyValue = (data?: { - [x: string]: any; -}): KeyValuePair[] => { +export const parseToKeyValue = ( + data?: { + [x: string]: any; + }, + valueToString?: boolean +): KeyValuePair[] => { return data ? Object.keys(data).reduce( - (pre: any[], cur: string) => [...pre, { key: cur, value: data[cur] }], + (pre: any[], cur: string) => [ + ...pre, + { key: cur, value: valueToString ? String(data[cur]) : data[cur] }, + ], [] ) : []; @@ -188,21 +198,34 @@ export const formatAddress = (address: string) => { */ export const assignTypeParams = ( field: FieldType, - typeParamKeys: string[] = ['dim', 'max_length', 'max_capacity'] + typeParamKeys: string[] = [ + 'dim', + 'max_length', + 'max_capacity', + 'enable_match', + 'enable_tokenizer', + 'tokenizer_params', + ] ) => { let newField = cloneObj(field); typeParamKeys.forEach(key => { if (newField.hasOwnProperty(key)) { // if the property exists in the field object, assign it to the type_params object newField.type_params = newField.type_params || {}; - newField.type_params[key] = String(newField[key as keyof FieldType]); + newField.type_params[key] = + typeof newField[key as keyof FieldType] !== 'object' + ? String(newField[key as keyof FieldType] ?? '') + : (newField[key as keyof FieldType] as TypeParam); // delete the property from the field object delete newField[key as keyof FieldType]; } if (newField.type_params && newField.type_params[key]) { - // if the property already exists in the type_params object, convert it to a string - newField.type_params[key] = String(newField.type_params[key]); + // if the property already exists in the type_params object, convert it to a string, + newField.type_params[key] = + typeof newField.type_params[key] !== 'object' + ? String(newField.type_params[key]) + : newField.type_params[key]; } }); return newField; @@ -266,7 +289,7 @@ export const convertToDataType = ( throw new Error(ERROR_REASONS.FIELD_TYPE_IS_NOT_SUPPORT); }; -/** +/**dd * Creates a deep copy of the provided object using JSON.parse and JSON.stringify. * Note that this function is not efficient and may cause performance issues if used with large or complex objects. It also does not handle cases where the object being cloned contains functions or prototype methods. * @@ -287,7 +310,7 @@ export const cloneObj = (obj: T): T => { */ export const formatCollectionSchema = ( data: CreateCollectionReq, - fieldSchemaType: Type + schemaTypes: Record ): { [k: string]: any } => { const { collection_name, @@ -295,6 +318,7 @@ export const formatCollectionSchema = ( enable_dynamic_field, enableDynamicField, partition_key_field, + functions, } = data; let fields = (data as CreateCollectionWithFieldsReq).fields; @@ -309,23 +333,31 @@ export const formatCollectionSchema = ( enableDynamicField: !!enableDynamicField || !!enable_dynamic_field, fields: fields.map(field => { // Assign the typeParams property to the result of parseToKeyValue(type_params). - const { type_params, ...rest } = assignTypeParams(field); + const { + type_params, + data_type, + element_type, + is_function_output, + is_partition_key, + is_primary_key, + ...rest + } = assignTypeParams(field); const dataType = convertToDataType(field.data_type); const createObj: any = { ...rest, typeParams: parseToKeyValue(type_params), + data_type, // compatibility with old version dataType, - isPrimaryKey: !!field.is_primary_key, + isPrimaryKey: !!is_primary_key, isPartitionKey: - !!field.is_partition_key || field.name === partition_key_field, + !!is_partition_key || field.name === partition_key_field, + isFunctionOutput: !!is_function_output, }; // if element type exist and - if ( - dataType === DataType.Array && - typeof field.element_type !== 'undefined' - ) { - createObj.elementType = convertToDataType(field.element_type); + if (dataType === DataType.Array && typeof element_type !== 'undefined') { + createObj.elementType = convertToDataType(element_type); + createObj.element_type = element_type; // compatibility with old version } if (typeof field.default_value !== 'undefined') { @@ -335,9 +367,23 @@ export const formatCollectionSchema = ( [dataKey]: field.default_value, }; } - return fieldSchemaType.create(createObj); + return schemaTypes.fieldSchemaType.create(createObj); }), - }; + functions: [], + } as any; + + // if functions is set, parse its params to key-value pairs, and delete inputs and outputs + if (functions) { + payload.functions = functions.map((func: any) => { + const { input_field_names, output_field_names, ...rest } = func; + return schemaTypes.functionSchemaType.create({ + ...rest, + inputFieldNames: input_field_names, + outputFieldNames: output_field_names, + params: parseToKeyValue(func.params, true), + }); + }); + } return payload; }; @@ -715,7 +761,9 @@ export const buildSearchRequest = ( // merge single request with hybrid request req = Object.assign(cloneObj(data), singleReq); } else { - // if it is not hybrid search, and we have built one request, skip + // if it is not hybrid search, and we have built one request + // or user has specified an anns_field to search and is not matching + // skip const skip = requests.length === 1 || (typeof req.anns_field !== 'undefined' && req.anns_field !== name); @@ -724,29 +772,29 @@ export const buildSearchRequest = ( } } - // get search vectors - let searchingVector: VectorTypes | VectorTypes[] = isHybridSearch + // get search data + let searchData: SearchDataType | SearchMultipleDataType = isHybridSearch ? req.data! : searchReq.vectors || searchSimpleReq.vectors || searchSimpleReq.vector || searchSimpleReq.data; - // format searching vector - searchingVector = formatSearchVector(searchingVector, field.dataType!); + // format searching data + searchData = formatSearchData(searchData, field); // create search request requests.push({ collection_name: req.collection_name, partition_names: req.partition_names || [], output_fields: req.output_fields || default_output_fields, - nq: searchReq.nq || searchingVector.length, + nq: searchReq.nq || searchData.length, dsl: searchReq.expr || searchSimpleReq.filter || '', dsl_type: DslType.BoolExprV1, placeholder_group: buildPlaceholderGroupBytes( milvusProto, - searchingVector as VectorTypes[], - field.dataType! + searchData as VectorTypes[], + field ), search_params: parseToKeyValue( searchReq.search_params || buildSearchParams(req, name) @@ -882,28 +930,36 @@ export const formatSearchResult = ( /** * Formats the search vector to match a specific data type. - * @param {VectorTypes | VectorTypes[]} searchVector - The search vector or array of vectors to be formatted. + * @param {SearchDataType[]} searchVector - The search vector or array of vectors to be formatted. * @param {DataType} dataType - The specified data type. * @returns {VectorTypes[]} The formatted search vector or array of vectors. */ -export const formatSearchVector = ( - searchVector: VectorTypes | VectorTypes[], - dataType: DataType -): VectorTypes[] => { +export const formatSearchData = ( + searchData: SearchDataType | SearchMultipleDataType, + field: FieldSchema +): SearchMultipleDataType => { + const { dataType, is_function_output } = field; + + if (is_function_output) { + return ( + Array.isArray(searchData) ? searchData : [searchData] + ) as SearchMultipleDataType; + } + switch (dataType) { case DataType.FloatVector: case DataType.BinaryVector: case DataType.Float16Vector: case DataType.BFloat16Vector: - if (!Array.isArray(searchVector)) { - return [searchVector] as VectorTypes[]; + if (!Array.isArray(searchData)) { + return [searchData] as VectorTypes[]; } case DataType.SparseFloatVector: - const type = getSparseFloatVectorType(searchVector as SparseVectorArray); + const type = getSparseFloatVectorType(searchData as SparseVectorArray); if (type !== 'unknown') { - return [searchVector] as VectorTypes[]; + return [searchData] as VectorTypes[]; } default: - return searchVector as VectorTypes[]; + return searchData as VectorTypes[]; } }; diff --git a/milvus/utils/Grpc.ts b/milvus/utils/Grpc.ts index 28459b89..7a450c92 100644 --- a/milvus/utils/Grpc.ts +++ b/milvus/utils/Grpc.ts @@ -200,7 +200,7 @@ export const getRetryInterceptor = ({ logger.debug( `\x1b[32m[Response(${ Date.now() - startTime.getTime() - }ms)]\x1b[0m\x1b[2m${clientId}\x1b[0m>${dbname}>\x1b[1m${methodName}\x1b[0m: ${msg}` + }ms)]\x1b[0m\x1b[2m${clientId}\x1b[0m>${dbname}>\x1b[1m${methodName}\x1b[0m: ${string}` ); savedMessageNext(savedReceiveMessage); @@ -217,7 +217,7 @@ export const getRetryInterceptor = ({ const msg = string.length > 2048 ? string.slice(0, 2048) + '...' : string; logger.debug( - `\x1b[34m[Request]\x1b[0m${clientId}>${dbname}>\x1b[1m${methodName}(${timeoutInSeconds})\x1b[0m: ${msg}` + `\x1b[34m[Request]\x1b[0m${clientId}>${dbname}>\x1b[1m${methodName}(${timeoutInSeconds})\x1b[0m: ${string}` ); savedSendMessage = message; next(message); diff --git a/package.json b/package.json index 3b9ec67a..cc7931a2 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "@zilliz/milvus2-sdk-node", "author": "ued@zilliz.com", - "milvusVersion": "master-20240911-42eef490-amd64", + "milvusVersion": "master-20241024-f78f6112-amd64", "version": "2.4.9", "main": "dist/milvus", "files": [ diff --git a/proto b/proto index 8f8ca678..85ccff4d 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 8f8ca67816cd2fee2b4c72f30c0ede66a7935087 +Subproject commit 85ccff4d57fe9c510c88ee4eaf1ba33ef4ef1188 diff --git a/test/grpc/Functions.spec.ts b/test/grpc/Functions.spec.ts new file mode 100644 index 00000000..f68c1301 --- /dev/null +++ b/test/grpc/Functions.spec.ts @@ -0,0 +1,242 @@ +import { + MilvusClient, + DataType, + ErrorCode, + MetricType, + ConsistencyLevelEnum, + FunctionType, +} from '../../milvus'; +import { + IP, + genCollectionParams, + GENERATE_NAME, + generateInsertData, + dynamicFields, +} from '../tools'; + +const milvusClient = new MilvusClient({ address: IP, logLevel: 'debug' }); +const COLLECTION = GENERATE_NAME(); +const dbParam = { + db_name: 'Functions', +}; +const numPartitions = 3; + +// create +const createCollectionParams = genCollectionParams({ + collectionName: COLLECTION, + dim: [4], + vectorType: [DataType.FloatVector], + autoID: false, + partitionKeyEnabled: true, + numPartitions, + enableDynamic: true, + fields: [ + { + name: 'text', + description: 'text field', + data_type: DataType.VarChar, + max_length: 20, + is_partition_key: false, + enable_tokenizer: true, + }, + { + name: 'sparse', + description: 'sparse field', + data_type: DataType.SparseFloatVector, + is_function_output: true, + }, + { + name: 'sparse2', + description: 'sparse field2', + data_type: DataType.SparseFloatVector, + is_function_output: true, + }, + ], + functions: [ + { + name: 'bm25f1', + description: 'bm25 function', + type: FunctionType.BM25, + input_field_names: ['text'], + output_field_names: ['sparse'], + params: {}, + }, + { + name: 'bm25f2', + description: 'bm25 function', + type: FunctionType.BM25, + input_field_names: ['text'], + output_field_names: ['sparse2'], + params: {}, + }, + ], +}); + +describe(`Functions schema API`, () => { + beforeAll(async () => { + // create db and use db + await milvusClient.createDatabase(dbParam); + await milvusClient.use(dbParam); + }); + afterAll(async () => { + await milvusClient.dropCollection({ + collection_name: COLLECTION, + }); + await milvusClient.dropDatabase(dbParam); + }); + + it(`Create schema with function collection should success`, async () => { + const create = await milvusClient.createCollection(createCollectionParams); + + expect(create.error_code).toEqual(ErrorCode.SUCCESS); + + // describe + const describe = await milvusClient.describeCollection({ + collection_name: COLLECTION, + }); + // expect the 'sparse' field to be created + expect(describe.schema.fields.length).toEqual( + createCollectionParams.fields.length + ); + // extract the 'sparse' field + const sparse = describe.schema.fields.find( + field => field.is_function_output + ); + // expect the 'sparse' field's name to be 'sparse' + expect(sparse!.name).toEqual('sparse'); + + // expect functions are in the schema + expect(describe.schema.functions.length).toEqual(2); + expect(describe.schema.functions[0].name).toEqual('bm25f1'); + expect(describe.schema.functions[0].input_field_names).toEqual(['text']); + expect(describe.schema.functions[0].output_field_names).toEqual(['sparse']); + expect(describe.schema.functions[0].type).toEqual('BM25'); + expect(describe.schema.functions[1].name).toEqual('bm25f2'); + expect(describe.schema.functions[1].input_field_names).toEqual(['text']); + expect(describe.schema.functions[1].output_field_names).toEqual([ + 'sparse2', + ]); + expect(describe.schema.functions[1].type).toEqual('BM25'); + }); + + it(`Insert data with function field should success`, async () => { + const data = generateInsertData( + [...createCollectionParams.fields, ...dynamicFields], + 10 + ); + + const insert = await milvusClient.insert({ + collection_name: COLLECTION, + fields_data: data, + }); + + expect(insert.status.error_code).toEqual(ErrorCode.SUCCESS); + }); + + it(`Create index on function output field should success`, async () => { + // create index + const createVectorIndex = await milvusClient.createIndex({ + collection_name: COLLECTION, + index_name: 't', + field_name: 'vector', + index_type: 'HNSW', + metric_type: MetricType.COSINE, + params: { M: 4, efConstruction: 8 }, + }); + + const createIndex = await milvusClient.createIndex({ + collection_name: COLLECTION, + index_name: 't2', + field_name: 'sparse', + index_type: 'SPARSE_INVERTED_INDEX', + metric_type: 'BM25', + params: { drop_ratio_build: 0.3, bm25_k1: 1.25, bm25_b: 0.8 }, + }); + + const createIndex2 = await milvusClient.createIndex({ + collection_name: COLLECTION, + index_name: 't3', + field_name: 'sparse2', + index_type: 'SPARSE_INVERTED_INDEX', + metric_type: 'BM25', + params: { drop_ratio_build: 0.3, bm25_k1: 1.25, bm25_b: 0.8 }, + }); + + expect(createVectorIndex.error_code).toEqual(ErrorCode.SUCCESS); + expect(createIndex.error_code).toEqual(ErrorCode.SUCCESS); + expect(createIndex2.error_code).toEqual(ErrorCode.SUCCESS); + + // load + const load = await milvusClient.loadCollection({ + collection_name: COLLECTION, + }); + + expect(load.error_code).toEqual(ErrorCode.SUCCESS); + }); + + it(`query with function output field should success`, async () => { + // query + const query = await milvusClient.query({ + collection_name: COLLECTION, + limit: 10, + expr: 'id > 0', + output_fields: ['vector', 'id', 'text', 'sparse', 'sparse2'], + consistency_level: ConsistencyLevelEnum.Strong, + }); + + expect(query.status.error_code).toEqual(ErrorCode.SUCCESS); + expect(query.data.length).toEqual(10); + // data should have 'sparse' field + expect(query.data[0].hasOwnProperty('sparse')).toBeTruthy(); + // data should have 'sparse2' field + expect(query.data[0].hasOwnProperty('sparse2')).toBeTruthy(); + }); + + it(`search with varchar should success`, async () => { + // search nq = 1 + const search = await milvusClient.search({ + collection_name: COLLECTION, + limit: 10, + data: 'apple', + anns_field: 'sparse', + output_fields: ['*'], + params: { drop_ratio_search: 0.6 }, + consistency_level: ConsistencyLevelEnum.Strong, + }); + + expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); + + // nq > 1 + const search2 = await milvusClient.search({ + collection_name: COLLECTION, + limit: 10, + data: ['apple', 'banana'], + anns_field: 'sparse', + output_fields: ['*'], + params: { drop_ratio_search: 0.6 }, + consistency_level: ConsistencyLevelEnum.Strong, + }); + + expect(search2.status.error_code).toEqual(ErrorCode.SUCCESS); + + // multiple search + const search3 = await milvusClient.search({ + collection_name: COLLECTION, + limit: 10, + data: [ + { + data: 'apple', + anns_field: 'sparse', + params: { nprobe: 2 }, + }, + { + data: [1, 2, 3, 4], + anns_field: 'vector', + }, + ], + consistency_level: ConsistencyLevelEnum.Strong, + }); + + expect(search3.status.error_code).toEqual(ErrorCode.SUCCESS); + }); +}); diff --git a/test/tools/collection.ts b/test/tools/collection.ts index 7740c71b..b0487d84 100644 --- a/test/tools/collection.ts +++ b/test/tools/collection.ts @@ -1,5 +1,10 @@ -import { DataType, ConsistencyLevelEnum } from '../../milvus'; -import { VECTOR_FIELD_NAME, MAX_CAPACITY, MAX_LENGTH } from './const'; +import { + DataType, + ConsistencyLevelEnum, + FunctionType, + Function, +} from '../../milvus'; +import { MAX_CAPACITY, MAX_LENGTH } from './const'; import { GENERATE_VECTOR_NAME } from './'; export const dynamicFields = [ @@ -41,6 +46,7 @@ export const genCollectionParams = (data: { enableDynamic?: boolean; maxCapacity?: number; idType?: DataType; + functions?: Function[]; }) => { const { collectionName, @@ -53,6 +59,7 @@ export const genCollectionParams = (data: { enableDynamic = false, maxCapacity, idType = DataType.Int64, + functions, } = data; const vectorFields = vectorType.map((type, i) => { @@ -115,6 +122,7 @@ export const genCollectionParams = (data: { data_type: DataType.VarChar, max_length: MAX_LENGTH, is_partition_key: partitionKeyEnabled, + enable_tokenizer: true, }, { name: 'json', @@ -152,5 +160,9 @@ export const genCollectionParams = (data: { params.num_partitions = numPartitions; } + if (functions && functions?.length > 0) { + params.functions = functions; + } + return params; }; diff --git a/test/tools/data.ts b/test/tools/data.ts index f1c4dfb2..b0c8e07b 100644 --- a/test/tools/data.ts +++ b/test/tools/data.ts @@ -295,6 +295,11 @@ export const generateInsertData = ( continue; } + // skip fields with is_function_output = true + if (field.is_function_output) { + continue; + } + // Parameters used to generate all types of data const genDataParams = { dim: Number(field.dim || (field.type_params && field.type_params.dim)), diff --git a/test/utils/Format.spec.ts b/test/utils/Format.spec.ts index e98d34d0..389715ee 100644 --- a/test/utils/Format.spec.ts +++ b/test/utils/Format.spec.ts @@ -26,8 +26,9 @@ import { buildFieldData, formatSearchResult, Field, - formatSearchVector, + formatSearchData, buildSearchRequest, + FieldSchema, } from '../../milvus'; describe('utils/format', () => { @@ -75,6 +76,15 @@ describe('utils/format', () => { ]); }); + it(`should convert {row_count:4, b: 3} t0 [{key:"row_count",value:'4'}, {key: "b", value: '4'}]`, () => { + const testValue = { row_count: '4', b: 3 }; + const res = parseToKeyValue(testValue, true); + expect(res).toMatchObject([ + { key: 'row_count', value: '4' }, + { key: 'b', value: '3' }, + ]); + }); + it(`should convert [{key:"row_count",value:4}] to {row_count:4}`, () => { const testValue = 3.1231241241234124124; const res = formatNumberPrecision(testValue, 3); @@ -166,12 +176,15 @@ describe('utils/format', () => { expect(methodName).toBe('123'); }); - it('should assign properties with keys `dim` or `max_length` to the `type_params` object and delete them from the `field` object', () => { + it('should assign properties with keys `dim` or `max_length` to the `type_params`, `enable_match`, `tokenizer_params`, `enable_tokenizer` object and delete them from the `field` object', () => { const field = { name: 'vector', data_type: 'BinaryVector', dim: 128, max_length: 100, + enable_match: true, + tokenizer_params: { key: 'value' }, + enable_tokenizer: true, } as FieldType; const expectedOutput = { name: 'vector', @@ -179,6 +192,9 @@ describe('utils/format', () => { type_params: { dim: '128', max_length: '100', + enable_match: 'true', + tokenizer_params: { key: 'value' }, + enable_tokenizer: 'true', }, }; expect(assignTypeParams(field)).toEqual(expectedOutput); @@ -261,7 +277,7 @@ describe('utils/format', () => { }, { name: 'testField2', - data_type: DataType.FloatVector, + data_type: 'FloatVector', is_primary_key: false, description: 'Test VECTOR field', dim: 64, @@ -274,7 +290,7 @@ describe('utils/format', () => { element_type: DataType.Int64, }, ], - }; + } as any; const schemaProtoPath = path.resolve( __dirname, @@ -285,6 +301,9 @@ describe('utils/format', () => { const fieldSchemaType = schemaProto.lookupType( 'milvus.proto.schema.FieldSchema' ); + const functionSchemaType = schemaProto.lookupType( + 'milvus.proto.schema.FunctionSchema' + ); const expectedResult = { name: 'testCollection', @@ -295,50 +314,45 @@ describe('utils/format', () => { typeParams: [], indexParams: [], name: 'testField1', - data_type: 5, - is_primary_key: true, description: 'Test PRIMARY KEY field', + data_type: 5, dataType: 5, isPrimaryKey: true, isPartitionKey: false, + isFunctionOutput: false, }, { - typeParams: [ - { - key: 'dim', - value: '64', - }, - ], + typeParams: [{ key: 'dim', value: '64' }], indexParams: [], name: 'testField2', - data_type: 101, - is_primary_key: false, description: 'Test VECTOR field', + data_type: 'FloatVector', dataType: 101, isPrimaryKey: false, isPartitionKey: false, + isFunctionOutput: false, }, { - typeParams: [ - { - key: 'max_capacity', - value: '64', - }, - ], + typeParams: [{ key: 'max_capacity', value: '64' }], indexParams: [], name: 'arrayField', - data_type: 22, description: 'Test Array field', - element_type: 5, + data_type: 22, dataType: 22, isPrimaryKey: false, isPartitionKey: false, + isFunctionOutput: false, elementType: 5, + element_type: 5, }, ], + functions: [], }; - const payload = formatCollectionSchema(data, fieldSchemaType); + const payload = formatCollectionSchema(data, { + fieldSchemaType, + functionSchemaType, + }); expect(payload).toEqual(expectedResult); }); @@ -377,6 +391,11 @@ describe('utils/format', () => { dataType: 101, autoID: false, state: 'created', + is_dynamic: false, + is_clustering_key: false, + is_function_output: false, + nullable: false, + is_partition_key: false, }, { fieldID: '2', @@ -389,12 +408,18 @@ describe('utils/format', () => { dataType: 5, autoID: true, state: 'created', + is_dynamic: false, + is_clustering_key: false, + is_function_output: false, + nullable: false, + is_partition_key: false, }, ], name: 'collection_v8mt0v7x', description: '', enable_dynamic_field: false, autoID: false, + functions: [], }, shards_num: 1, start_positions: [], @@ -405,6 +430,7 @@ describe('utils/format', () => { num_partitions: '0', collection_name: 'test', db_name: '', + functions: [], }; const formatted = formatDescribedCol(response); @@ -597,19 +623,29 @@ describe('utils/format', () => { it('should format search vector correctly', () => { // float vector const floatVector = [1, 2, 3]; - const formattedVector = formatSearchVector( - floatVector, - DataType.FloatVector - ); + const formattedVector = formatSearchData(floatVector, { + dataType: DataType.FloatVector, + } as FieldSchema); expect(formattedVector).toEqual([floatVector]); const floatVectors = [ [1, 2, 3], [4, 5, 6], ]; - expect(formatSearchVector(floatVectors, DataType.FloatVector)).toEqual( - floatVectors - ); + expect( + formatSearchData(floatVectors, { + dataType: DataType.FloatVector, + } as FieldSchema) + ).toEqual(floatVectors); + + // varchar + const varcharVector = 'hello world'; + expect( + formatSearchData(varcharVector, { + dataType: DataType.SparseFloatVector, + is_function_output: true, + } as FieldSchema) + ).toEqual([varcharVector]); }); it('should format sparse vectors correctly', () => { @@ -618,10 +654,9 @@ describe('utils/format', () => { { index: 1, value: 2 }, { index: 3, value: 4 }, ]; - const formattedSparseCooVector = formatSearchVector( - sparseCooVector, - DataType.SparseFloatVector - ); + const formattedSparseCooVector = formatSearchData(sparseCooVector, { + dataType: DataType.SparseFloatVector, + } as FieldSchema); expect(formattedSparseCooVector).toEqual([sparseCooVector]); // sparse csr vector @@ -629,10 +664,9 @@ describe('utils/format', () => { indices: [1, 3], values: [2, 4], }; - const formattedSparseCsrVector = formatSearchVector( - sparseCsrVector, - DataType.SparseFloatVector - ); + const formattedSparseCsrVector = formatSearchData(sparseCsrVector, { + dataType: DataType.SparseFloatVector, + } as FieldSchema); expect(formattedSparseCsrVector).toEqual([sparseCsrVector]); const sparseCsrVectors = [ @@ -645,46 +679,41 @@ describe('utils/format', () => { values: [3, 5], }, ]; - const formattedSparseCsrVectors = formatSearchVector( - sparseCsrVectors, - DataType.SparseFloatVector - ); + const formattedSparseCsrVectors = formatSearchData(sparseCsrVectors, { + dataType: DataType.SparseFloatVector, + } as FieldSchema); expect(formattedSparseCsrVectors).toEqual(sparseCsrVectors); // sparse array vector const sparseArrayVector = [0.1, 0.2, 0.3]; - const formattedSparseArrayVector = formatSearchVector( - sparseArrayVector, - DataType.SparseFloatVector - ); + const formattedSparseArrayVector = formatSearchData(sparseArrayVector, { + dataType: DataType.SparseFloatVector, + } as FieldSchema); expect(formattedSparseArrayVector).toEqual([sparseArrayVector]); const sparseArrayVectors = [ [0.1, 0.2, 0.3], [0.4, 0.5, 0.6], ]; - const formattedSparseArrayVectors = formatSearchVector( - sparseArrayVectors, - DataType.SparseFloatVector - ); + const formattedSparseArrayVectors = formatSearchData(sparseArrayVectors, { + dataType: DataType.SparseFloatVector, + } as FieldSchema); expect(formattedSparseArrayVectors).toEqual(sparseArrayVectors); // sparse dict vector const sparseDictVector = { 1: 2, 3: 4 }; - const formattedSparseDictVector = formatSearchVector( - sparseDictVector, - DataType.SparseFloatVector - ); + const formattedSparseDictVector = formatSearchData(sparseDictVector, { + dataType: DataType.SparseFloatVector, + } as FieldSchema); expect(formattedSparseDictVector).toEqual([sparseDictVector]); const sparseDictVectors = [ { 1: 2, 3: 4 }, { 1: 2, 3: 4 }, ]; - const formattedSparseDictVectors = formatSearchVector( - sparseDictVectors, - DataType.SparseFloatVector - ); + const formattedSparseDictVectors = formatSearchData(sparseDictVectors, { + dataType: DataType.SparseFloatVector, + } as FieldSchema); expect(formattedSparseDictVectors).toEqual(sparseDictVectors); }); @@ -866,7 +895,6 @@ describe('utils/format', () => { describeCollectionResponse, milvusProto ); - console.dir(searchRequest, { depth: null }); expect(searchRequest.isHybridSearch).toEqual(true); expect(searchRequest.request.collection_name).toEqual('test'); expect(searchRequest.request.output_fields).toEqual(['vector', 'vector1']);