Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
Signed-off-by: shanghaikid <[email protected]>
  • Loading branch information
shanghaikid committed Sep 15, 2024
1 parent 1cc6cbf commit f41d8e6
Showing 1 changed file with 286 additions and 71 deletions.
357 changes: 286 additions & 71 deletions test/utils/Format.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import {
formatSearchResult,
Field,
formatSearchVector,
buildSearchRequest,
} from '../../milvus';

describe('utils/format', () => {
Expand Down Expand Up @@ -611,77 +612,291 @@ describe('utils/format', () => {
);
});

// sparse coo vector
const sparseCooVector = [
{ index: 1, value: 2 },
{ index: 3, value: 4 },
];
const formattedSparseCooVector = formatSearchVector(
sparseCooVector,
DataType.SparseFloatVector
);
expect(formattedSparseCooVector).toEqual([sparseCooVector]);

// sparse csr vector
const sparseCsrVector = {
indices: [1, 3],
values: [2, 4],
};
const formattedSparseCsrVector = formatSearchVector(
sparseCsrVector,
DataType.SparseFloatVector
);
expect(formattedSparseCsrVector).toEqual([sparseCsrVector]);

const sparseCsrVectors = [
{
it('should format sparse vectors correctly', () => {
// sparse coo vector
const sparseCooVector = [
{ index: 1, value: 2 },
{ index: 3, value: 4 },
];
const formattedSparseCooVector = formatSearchVector(
sparseCooVector,
DataType.SparseFloatVector
);
expect(formattedSparseCooVector).toEqual([sparseCooVector]);

// sparse csr vector
const sparseCsrVector = {
indices: [1, 3],
values: [2, 4],
},
{
indices: [2, 4],
values: [3, 5],
},
];
const formattedSparseCsrVectors = formatSearchVector(
sparseCsrVectors,
DataType.SparseFloatVector
);
expect(formattedSparseCsrVectors).toEqual(sparseCsrVectors);

// sparse array vector
const sparseArrayVector = [0.1, 0.2, 0.3];
const formattedSparseArrayVector = formatSearchVector(
sparseArrayVector,
DataType.SparseFloatVector
);
expect(formattedSparseArrayVector).toEqual([sparseArrayVector]);

const sparseArrayVectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
];
const formattedSparseArrayVectors = formatSearchVector(
sparseArrayVectors,
DataType.SparseFloatVector
);
expect(formattedSparseArrayVectors).toEqual(sparseArrayVectors);

// sparse dict vector
const sparseDictVector = { 1: 2, 3: 4 };
const formattedSparseDictVector = formatSearchVector(
sparseDictVector,
DataType.SparseFloatVector
);
expect(formattedSparseDictVector).toEqual([sparseDictVector]);

const sparseDictVectors = [
{ 1: 2, 3: 4 },
{ 1: 2, 3: 4 },
];
const formattedSparseDictVectors = formatSearchVector(
sparseDictVectors,
DataType.SparseFloatVector
);
expect(formattedSparseDictVectors).toEqual(sparseDictVectors);
};
const formattedSparseCsrVector = formatSearchVector(
sparseCsrVector,
DataType.SparseFloatVector
);
expect(formattedSparseCsrVector).toEqual([sparseCsrVector]);

const sparseCsrVectors = [
{
indices: [1, 3],
values: [2, 4],
},
{
indices: [2, 4],
values: [3, 5],
},
];
const formattedSparseCsrVectors = formatSearchVector(
sparseCsrVectors,
DataType.SparseFloatVector
);
expect(formattedSparseCsrVectors).toEqual(sparseCsrVectors);

// sparse array vector
const sparseArrayVector = [0.1, 0.2, 0.3];
const formattedSparseArrayVector = formatSearchVector(
sparseArrayVector,
DataType.SparseFloatVector
);
expect(formattedSparseArrayVector).toEqual([sparseArrayVector]);

const sparseArrayVectors = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
];
const formattedSparseArrayVectors = formatSearchVector(
sparseArrayVectors,
DataType.SparseFloatVector
);
expect(formattedSparseArrayVectors).toEqual(sparseArrayVectors);

// sparse dict vector
const sparseDictVector = { 1: 2, 3: 4 };
const formattedSparseDictVector = formatSearchVector(
sparseDictVector,
DataType.SparseFloatVector
);
expect(formattedSparseDictVector).toEqual([sparseDictVector]);

const sparseDictVectors = [
{ 1: 2, 3: 4 },
{ 1: 2, 3: 4 },
];
const formattedSparseDictVectors = formatSearchVector(
sparseDictVectors,
DataType.SparseFloatVector
);
expect(formattedSparseDictVectors).toEqual(sparseDictVectors);
});

it('should build single search request correctly', () => {
// path
const milvusProtoPath = path.resolve(
__dirname,
'../../proto/proto/milvus.proto'
);
const milvusProto = protobuf.loadSync(milvusProtoPath);

const searchParams = {
collection_name: 'test',
data: [
[1, 2, 3],
[4, 5, 6],
],
expr: 'id > 0',
output_fields: ['*'],
};

const describeCollectionResponse = {
status: { error_code: 'Success', reason: '' },
collection_name: 'test',
collectionID: 0,
consistency_level: 'Session',
num_partitions: '0',
aliases: [],
virtual_channel_names: {},
physical_channel_names: {},
start_positions: [],
shards_num: 1,
created_timestamp: '0',
created_utc_timestamp: '0',
properties: [],
db_name: '',
schema: {
name: 'test',
description: '',
enable_dynamic_field: false,
autoID: false,
fields: [
{
name: 'id',
fieldID: '1',
dataType: 5,
is_primary_key: true,
description: 'id field',
data_type: 'Int64',
type_params: [],
index_params: [],
},
{
name: 'vector',
fieldID: '2',
dataType: 101,
is_primary_key: false,
description: 'vector field',
data_type: 'FloatVector',
type_params: [{ key: 'dim', value: '3' }],
index_params: [],
},
],
},
} as any;

const searchRequest = buildSearchRequest(
searchParams,
describeCollectionResponse,
milvusProto
);
expect(searchRequest.isHybridSearch).toEqual(false);
expect(searchRequest.request.collection_name).toEqual('test');
expect(searchRequest.request.output_fields).toEqual(['*']);
expect(searchRequest.request.consistency_level).toEqual('Session');
expect(searchRequest.nq).toEqual(2);
const searchParamsKeyValuePairArray = (searchRequest.request as any)
.search_params;

// transform key value to object
const searchParamsKeyValuePairObject = searchParamsKeyValuePairArray.reduce(
(acc: any, { key, value }: any) => {
acc[key] = value;
return acc;
},
{}
);

expect(searchParamsKeyValuePairObject.anns_field).toEqual('vector');
expect(searchParamsKeyValuePairObject.params).toEqual('{}');
expect(searchParamsKeyValuePairObject.topk).toEqual(100);
expect(searchParamsKeyValuePairObject.offset).toEqual(0);
expect(searchParamsKeyValuePairObject.metric_type).toEqual('');
expect(searchParamsKeyValuePairObject.ignore_growing).toEqual(false);
});

it('should build hybrid search request correctly', () => {
// path
const milvusProtoPath = path.resolve(
__dirname,
'../../proto/proto/milvus.proto'
);
const milvusProto = protobuf.loadSync(milvusProtoPath);

const searchParams = {
collection_name: 'test',
data: [
{
data: [1, 2, 3, 4, 5, 6, 7, 8],
anns_field: 'vector',
params: { nprobe: 2 },
},
{
data: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
anns_field: 'vector1',
},
],
limit: 2,
output_fields: ['vector', 'vector1'],
};

const describeCollectionResponse = {
status: { error_code: 'Success', reason: '' },
collection_name: 'test',
collectionID: 0,
consistency_level: 'Session',
num_partitions: '0',
aliases: [],
virtual_channel_names: {},
physical_channel_names: {},
start_positions: [],
shards_num: 1,
created_timestamp: '0',
created_utc_timestamp: '0',
properties: [],
db_name: '',
schema: {
name: 'test',
description: '',
enable_dynamic_field: false,
autoID: false,
fields: [
{
name: 'id',
fieldID: '1',
dataType: 5,
is_primary_key: true,
description: 'id field',
data_type: 'Int64',
type_params: [],
index_params: [],
},
{
name: 'vector',
fieldID: '2',
dataType: 101,
is_primary_key: false,
description: 'vector field',
data_type: 'FloatVector',
type_params: [{ key: 'dim', value: '3' }],
index_params: [],
},
{
name: 'vector1',
fieldID: '2',
dataType: 101,
is_primary_key: false,
description: 'vector field2',
data_type: 'FloatVector',
type_params: [{ key: 'dim', value: '3' }],
index_params: [],
},
],
},
} as any;

const searchRequest = buildSearchRequest(
searchParams,
describeCollectionResponse,
milvusProto
);
console.dir(searchRequest, { depth: null });
expect(searchRequest.isHybridSearch).toEqual(true);
expect(searchRequest.request.collection_name).toEqual('test');
expect(searchRequest.request.output_fields).toEqual(['vector', 'vector1']);
expect(searchRequest.request.consistency_level).toEqual('Session');
expect(searchRequest.nq).toEqual(1);

(searchRequest.request as any).requests.forEach(
(request: any, index: number) => {
const searchParamsKeyValuePairArray = request.search_params;

// transform key value to object
const searchParamsKeyValuePairObject =
searchParamsKeyValuePairArray.reduce(
(acc: any, { key, value }: any) => {
acc[key] = value;
return acc;
},
{}
);

if (index === 0) {
expect(searchParamsKeyValuePairObject.anns_field).toEqual('vector');
expect(searchParamsKeyValuePairObject.params).toEqual('{"nprobe":2}');
expect(searchParamsKeyValuePairObject.topk).toEqual(2);
} else {
expect(searchParamsKeyValuePairObject.anns_field).toEqual('vector1');
expect(searchParamsKeyValuePairObject.params).toEqual('{}');
expect(searchParamsKeyValuePairObject.topk).toEqual(2);
}
}
);
});
});

0 comments on commit f41d8e6

Please sign in to comment.