diff --git a/milvus/utils/Format.ts b/milvus/utils/Format.ts index 53da147a..b5e8aa47 100644 --- a/milvus/utils/Format.ts +++ b/milvus/utils/Format.ts @@ -693,62 +693,68 @@ export const buildSearchRequest = ( searchHybridReq.data[0].anns_field ); + // output fields(reference fields) + const default_output_fields: string[] = ['*']; + // Iterate through collection fields, create search request for (let i = 0; i < collectionInfo.schema.fields.length; i++) { const field = collectionInfo.schema.fields[i]; - const { name } = field; - - let req: SearchSimpleReq | (HybridSearchReq & HybridSearchSingleReq) = - data as SearchSimpleReq; - - if (isHybridSearch) { - const singleReq = searchHybridReq.data.find(d => d.anns_field === name); - // if it is hybrid search and no request target is not found, skip - if (!singleReq) { - continue; - } - // merge single request with hybrid request - req = Object.assign(cloneObj(data), singleReq); - } else { - // if it is not hybrid search, and we have built one request, skip - const skip = - requests.length === 1 || - (typeof req.anns_field !== 'undefined' && req.anns_field !== name); - if (skip) { - continue; + const { name, dataType } = field; + + // if field type is vector, build the request + if (isVectorType(dataType)) { + let req: SearchSimpleReq | (HybridSearchReq & HybridSearchSingleReq) = + data as SearchSimpleReq; + + if (isHybridSearch) { + const singleReq = searchHybridReq.data.find(d => d.anns_field === name); + // if it is hybrid search and no request target is not found, skip + if (!singleReq) { + continue; + } + // merge single request with hybrid request + req = Object.assign(cloneObj(data), singleReq); + } else { + // if it is not hybrid search, and we have built one request, skip + const skip = + requests.length === 1 || + (typeof req.anns_field !== 'undefined' && req.anns_field !== name); + if (skip) { + continue; + } } - } - // get search vectors - let searchingVector: VectorTypes | VectorTypes[] = isHybridSearch - ? req.data! - : searchReq.vectors || - searchSimpleReq.vectors || - searchSimpleReq.vector || - searchSimpleReq.data; - - // format searching vector - searchingVector = formatSearchVector(searchingVector, field.dataType!); - - // create search request - requests.push({ - collection_name: req.collection_name, - partition_names: req.partition_names || [], - output_fields: req.output_fields || ['*'], - nq: searchReq.nq || searchingVector.length, - dsl: searchReq.expr || searchSimpleReq.filter || '', - dsl_type: DslType.BoolExprV1, - placeholder_group: buildPlaceholderGroupBytes( - milvusProto, - searchingVector as VectorTypes[], - field.dataType! - ), - search_params: parseToKeyValue( - searchReq.search_params || buildSearchParams(req, name) - ), - consistency_level: - req.consistency_level || (collectionInfo.consistency_level as any), - }); + // get search vectors + let searchingVector: VectorTypes | VectorTypes[] = isHybridSearch + ? req.data! + : searchReq.vectors || + searchSimpleReq.vectors || + searchSimpleReq.vector || + searchSimpleReq.data; + + // format searching vector + searchingVector = formatSearchVector(searchingVector, field.dataType!); + + // create search request + requests.push({ + collection_name: req.collection_name, + partition_names: req.partition_names || [], + output_fields: req.output_fields || default_output_fields, + nq: searchReq.nq || searchingVector.length, + dsl: searchReq.expr || searchSimpleReq.filter || '', + dsl_type: DslType.BoolExprV1, + placeholder_group: buildPlaceholderGroupBytes( + milvusProto, + searchingVector as VectorTypes[], + field.dataType! + ), + search_params: parseToKeyValue( + searchReq.search_params || buildSearchParams(req, name) + ), + consistency_level: + req.consistency_level || (collectionInfo.consistency_level as any), + }); + } } /** diff --git a/package.json b/package.json index d4fc22c5..5f341bec 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "name": "@zilliz/milvus2-sdk-node", "author": "ued@zilliz.com", "version": "2.4.8", - "milvusVersion": "v2.4.10", + "milvusVersion": "v2.4.11", "main": "dist/milvus", "files": [ "dist" diff --git a/test/utils/Format.spec.ts b/test/utils/Format.spec.ts index b0d086cc..e98d34d0 100644 --- a/test/utils/Format.spec.ts +++ b/test/utils/Format.spec.ts @@ -27,6 +27,7 @@ import { formatSearchResult, Field, formatSearchVector, + buildSearchRequest, } from '../../milvus'; describe('utils/format', () => { @@ -611,77 +612,291 @@ describe('utils/format', () => { ); }); - // sparse coo vector - const sparseCooVector = [ - { index: 1, value: 2 }, - { index: 3, value: 4 }, - ]; - const formattedSparseCooVector = formatSearchVector( - sparseCooVector, - DataType.SparseFloatVector - ); - expect(formattedSparseCooVector).toEqual([sparseCooVector]); - - // sparse csr vector - const sparseCsrVector = { - indices: [1, 3], - values: [2, 4], - }; - const formattedSparseCsrVector = formatSearchVector( - sparseCsrVector, - DataType.SparseFloatVector - ); - expect(formattedSparseCsrVector).toEqual([sparseCsrVector]); - - const sparseCsrVectors = [ - { + it('should format sparse vectors correctly', () => { + // sparse coo vector + const sparseCooVector = [ + { index: 1, value: 2 }, + { index: 3, value: 4 }, + ]; + const formattedSparseCooVector = formatSearchVector( + sparseCooVector, + DataType.SparseFloatVector + ); + expect(formattedSparseCooVector).toEqual([sparseCooVector]); + + // sparse csr vector + const sparseCsrVector = { indices: [1, 3], values: [2, 4], - }, - { - indices: [2, 4], - values: [3, 5], - }, - ]; - const formattedSparseCsrVectors = formatSearchVector( - sparseCsrVectors, - DataType.SparseFloatVector - ); - expect(formattedSparseCsrVectors).toEqual(sparseCsrVectors); - - // sparse array vector - const sparseArrayVector = [0.1, 0.2, 0.3]; - const formattedSparseArrayVector = formatSearchVector( - sparseArrayVector, - DataType.SparseFloatVector - ); - expect(formattedSparseArrayVector).toEqual([sparseArrayVector]); - - const sparseArrayVectors = [ - [0.1, 0.2, 0.3], - [0.4, 0.5, 0.6], - ]; - const formattedSparseArrayVectors = formatSearchVector( - sparseArrayVectors, - DataType.SparseFloatVector - ); - expect(formattedSparseArrayVectors).toEqual(sparseArrayVectors); - - // sparse dict vector - const sparseDictVector = { 1: 2, 3: 4 }; - const formattedSparseDictVector = formatSearchVector( - sparseDictVector, - DataType.SparseFloatVector - ); - expect(formattedSparseDictVector).toEqual([sparseDictVector]); - - const sparseDictVectors = [ - { 1: 2, 3: 4 }, - { 1: 2, 3: 4 }, - ]; - const formattedSparseDictVectors = formatSearchVector( - sparseDictVectors, - DataType.SparseFloatVector - ); - expect(formattedSparseDictVectors).toEqual(sparseDictVectors); + }; + const formattedSparseCsrVector = formatSearchVector( + sparseCsrVector, + DataType.SparseFloatVector + ); + expect(formattedSparseCsrVector).toEqual([sparseCsrVector]); + + const sparseCsrVectors = [ + { + indices: [1, 3], + values: [2, 4], + }, + { + indices: [2, 4], + values: [3, 5], + }, + ]; + const formattedSparseCsrVectors = formatSearchVector( + sparseCsrVectors, + DataType.SparseFloatVector + ); + expect(formattedSparseCsrVectors).toEqual(sparseCsrVectors); + + // sparse array vector + const sparseArrayVector = [0.1, 0.2, 0.3]; + const formattedSparseArrayVector = formatSearchVector( + sparseArrayVector, + DataType.SparseFloatVector + ); + expect(formattedSparseArrayVector).toEqual([sparseArrayVector]); + + const sparseArrayVectors = [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + ]; + const formattedSparseArrayVectors = formatSearchVector( + sparseArrayVectors, + DataType.SparseFloatVector + ); + expect(formattedSparseArrayVectors).toEqual(sparseArrayVectors); + + // sparse dict vector + const sparseDictVector = { 1: 2, 3: 4 }; + const formattedSparseDictVector = formatSearchVector( + sparseDictVector, + DataType.SparseFloatVector + ); + expect(formattedSparseDictVector).toEqual([sparseDictVector]); + + const sparseDictVectors = [ + { 1: 2, 3: 4 }, + { 1: 2, 3: 4 }, + ]; + const formattedSparseDictVectors = formatSearchVector( + sparseDictVectors, + DataType.SparseFloatVector + ); + expect(formattedSparseDictVectors).toEqual(sparseDictVectors); + }); + + it('should build single search request correctly', () => { + // path + const milvusProtoPath = path.resolve( + __dirname, + '../../proto/proto/milvus.proto' + ); + const milvusProto = protobuf.loadSync(milvusProtoPath); + + const searchParams = { + collection_name: 'test', + data: [ + [1, 2, 3], + [4, 5, 6], + ], + expr: 'id > 0', + output_fields: ['*'], + }; + + const describeCollectionResponse = { + status: { error_code: 'Success', reason: '' }, + collection_name: 'test', + collectionID: 0, + consistency_level: 'Session', + num_partitions: '0', + aliases: [], + virtual_channel_names: {}, + physical_channel_names: {}, + start_positions: [], + shards_num: 1, + created_timestamp: '0', + created_utc_timestamp: '0', + properties: [], + db_name: '', + schema: { + name: 'test', + description: '', + enable_dynamic_field: false, + autoID: false, + fields: [ + { + name: 'id', + fieldID: '1', + dataType: 5, + is_primary_key: true, + description: 'id field', + data_type: 'Int64', + type_params: [], + index_params: [], + }, + { + name: 'vector', + fieldID: '2', + dataType: 101, + is_primary_key: false, + description: 'vector field', + data_type: 'FloatVector', + type_params: [{ key: 'dim', value: '3' }], + index_params: [], + }, + ], + }, + } as any; + + const searchRequest = buildSearchRequest( + searchParams, + describeCollectionResponse, + milvusProto + ); + expect(searchRequest.isHybridSearch).toEqual(false); + expect(searchRequest.request.collection_name).toEqual('test'); + expect(searchRequest.request.output_fields).toEqual(['*']); + expect(searchRequest.request.consistency_level).toEqual('Session'); + expect(searchRequest.nq).toEqual(2); + const searchParamsKeyValuePairArray = (searchRequest.request as any) + .search_params; + + // transform key value to object + const searchParamsKeyValuePairObject = searchParamsKeyValuePairArray.reduce( + (acc: any, { key, value }: any) => { + acc[key] = value; + return acc; + }, + {} + ); + + expect(searchParamsKeyValuePairObject.anns_field).toEqual('vector'); + expect(searchParamsKeyValuePairObject.params).toEqual('{}'); + expect(searchParamsKeyValuePairObject.topk).toEqual(100); + expect(searchParamsKeyValuePairObject.offset).toEqual(0); + expect(searchParamsKeyValuePairObject.metric_type).toEqual(''); + expect(searchParamsKeyValuePairObject.ignore_growing).toEqual(false); + }); + + it('should build hybrid search request correctly', () => { + // path + const milvusProtoPath = path.resolve( + __dirname, + '../../proto/proto/milvus.proto' + ); + const milvusProto = protobuf.loadSync(milvusProtoPath); + + const searchParams = { + collection_name: 'test', + data: [ + { + data: [1, 2, 3, 4, 5, 6, 7, 8], + anns_field: 'vector', + params: { nprobe: 2 }, + }, + { + data: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + anns_field: 'vector1', + }, + ], + limit: 2, + output_fields: ['vector', 'vector1'], + }; + + const describeCollectionResponse = { + status: { error_code: 'Success', reason: '' }, + collection_name: 'test', + collectionID: 0, + consistency_level: 'Session', + num_partitions: '0', + aliases: [], + virtual_channel_names: {}, + physical_channel_names: {}, + start_positions: [], + shards_num: 1, + created_timestamp: '0', + created_utc_timestamp: '0', + properties: [], + db_name: '', + schema: { + name: 'test', + description: '', + enable_dynamic_field: false, + autoID: false, + fields: [ + { + name: 'id', + fieldID: '1', + dataType: 5, + is_primary_key: true, + description: 'id field', + data_type: 'Int64', + type_params: [], + index_params: [], + }, + { + name: 'vector', + fieldID: '2', + dataType: 101, + is_primary_key: false, + description: 'vector field', + data_type: 'FloatVector', + type_params: [{ key: 'dim', value: '3' }], + index_params: [], + }, + { + name: 'vector1', + fieldID: '2', + dataType: 101, + is_primary_key: false, + description: 'vector field2', + data_type: 'FloatVector', + type_params: [{ key: 'dim', value: '3' }], + index_params: [], + }, + ], + }, + } as any; + + const searchRequest = buildSearchRequest( + searchParams, + describeCollectionResponse, + milvusProto + ); + console.dir(searchRequest, { depth: null }); + expect(searchRequest.isHybridSearch).toEqual(true); + expect(searchRequest.request.collection_name).toEqual('test'); + expect(searchRequest.request.output_fields).toEqual(['vector', 'vector1']); + expect(searchRequest.request.consistency_level).toEqual('Session'); + expect(searchRequest.nq).toEqual(1); + + (searchRequest.request as any).requests.forEach( + (request: any, index: number) => { + const searchParamsKeyValuePairArray = request.search_params; + + // transform key value to object + const searchParamsKeyValuePairObject = + searchParamsKeyValuePairArray.reduce( + (acc: any, { key, value }: any) => { + acc[key] = value; + return acc; + }, + {} + ); + + if (index === 0) { + expect(searchParamsKeyValuePairObject.anns_field).toEqual('vector'); + expect(searchParamsKeyValuePairObject.params).toEqual('{"nprobe":2}'); + expect(searchParamsKeyValuePairObject.topk).toEqual(2); + } else { + expect(searchParamsKeyValuePairObject.anns_field).toEqual('vector1'); + expect(searchParamsKeyValuePairObject.params).toEqual('{}'); + expect(searchParamsKeyValuePairObject.topk).toEqual(2); + } + } + ); + }); });