Skip to content

Commit

Permalink
Re #1786 added possibility to use uint64 keys in fast_map and unit te…
Browse files Browse the repository at this point in the history
…sts for them.
  • Loading branch information
abuts committed Feb 18, 2025
1 parent 4bf3668 commit 01a28b1
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 11 deletions.
39 changes: 37 additions & 2 deletions _test/test_utilities_herbert/fast_map_vs_map_performance.m
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
% more keys then necessary to ensure sufficient unique keys pool.
base_key = unique(base_key); % ensure absence of duplicated keys


n_keys = min(n_keys,numel(base_key));
base_key = base_key(1:n_keys); % leave the expected number of keys
base_val = 1:numel(base_key);
keysUint = uint32(base_key); % convert them into requested type.
mm = min_max(keysUint) % display range of keys used in tests

Expand Down Expand Up @@ -57,7 +59,7 @@
if fm.isKey(test_keys(i))
idx1 = fm.get(test_keys(i));
else
idx = fm.n_members;
idx = fm.n_members;
fm = fm.add(test_keys(i),idx+1);
end
end
Expand Down Expand Up @@ -97,10 +99,43 @@
fprintf('Find keys in FAST MAP opt map takes %gsec\n',tv)

% Measure access time for optimized fast_map using remapper method for all keys
fm = fast_map(base_key,1:numel(base_key));
fm = fast_map(base_key,base_val);
fm.optimized = true;
tv = tic;
idx1 = fm.get_values_for_keys(test_keys,true);
tv = toc(tv);
fprintf('Find all keys in FAST MAP opt map takes %gsec\n',tv)
%--------------------------------------------------------------------------
% Measure access time for fast_map with uint64 keys
test_keys = uint64(test_keys);
fm = fast_map(uint64(base_key),base_val);
fm.optimized = false;
tv = tic;
for i=1:n_idx
idx1 = fm.get(test_keys(i));
end
tv = toc(tv);
fprintf('Find keys one-by-one in FAST uint64 MAP takes %gsec\n',tv)

% Measure access time for fast_map using remapper method for all keys
tv = tic;
fm.optimized = false;
idx1 = fm.get_values_for_keys(test_keys);
tv = toc(tv);
fprintf('Find all keys in FAST uint64 MAP non-opt takes %gsec\n',tv)

% Measure access time for optimized fast_map
fm.optimized = true;
tv = tic;
for i=1:n_idx
idx1 = fm.get(test_keys(i));
end
tv = toc(tv);
fprintf('Find keys one-by-one in FAST uint64 opt MAP takes %gsec\n',tv)

% Measure access time for optimized fast_map using remapper method for all keys
fm.optimized = true;
tv = tic;
idx1 = fm.get_values_for_keys(test_keys,true);
tv = toc(tv);
fprintf('Find all keys in FAST uint64 MAP optimized takes %gsec\n',tv)
24 changes: 17 additions & 7 deletions _test/test_utilities_herbert/test_fast_map.m
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@
end
%------------------------------------------------------------------
%------------------------------------------------------------------
function test_set_different_key_type(~)
fm = fast_map();
fm.KeyType = uint64(1);
fm = fm.add(10,1);

assertEqual(fm.KeyType,'uint64');
assertEqual(fm.keys,uint64(10));
assertEqual(fm.values,1);
end
%------------------------------------------------------------------
function test_get_all_val_for_keys_optimized_no_checks(~)
n_keys = 100;
base_key = 10+round(rand(1,10*n_keys)*(10*n_keys-1));
Expand All @@ -23,10 +33,10 @@ function test_get_all_val_for_keys_optimized_no_checks(~)
fm.optimized = true;

valm = fm.get_values_for_keys(base_key,false);

assertEqual(val,valm);
end

function test_get_all_val_for_keys_optimized_with_checks(~)
n_keys = 100;
base_key = 10+round(rand(1,10*n_keys)*(10*n_keys-1));
Expand All @@ -38,10 +48,10 @@ function test_get_all_val_for_keys_optimized_with_checks(~)
fm.optimized = true;

valm = fm.get_values_for_keys(base_key,false);

assertEqual(val,valm);
end

function test_get_all_val_for_keys(~)
n_keys = 100;
base_key = 10+round(rand(1,10*n_keys)*(10*n_keys-1));
Expand All @@ -51,10 +61,10 @@ function test_get_all_val_for_keys(~)

fm = fast_map(base_key,val);
valm = fm.get_values_for_keys(base_key);

assertEqual(val,valm);
end
%------------------------------------------------------------------
%------------------------------------------------------------------
function test_insertion_in_optimized(~)
n_keys = 100;
base_key = 10+round(rand(1,10*n_keys)*(10*n_keys-1));
Expand Down Expand Up @@ -129,7 +139,7 @@ function test_map_constrcutrion_from_cellarray(~)
end
function test_fast_map_construction(~)
val = 10:-1:1;
fm = fast_map(1:10,val );
fm = fast_map(uint32(1:10),val );
assertEqual(fm.keys,uint32(1:10));
assertEqual(fm.values,val);

Expand Down
34 changes: 32 additions & 2 deletions herbert_core/utilities/classes/@fast_map/fast_map.m
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@
return;
end
obj.do_check_combo_arg_ = false;
if iscell(keys)
obj.KeyType = class(keys{1});
else
obj.KeyType = class(keys);
end
obj.keys = keys;
obj.values = values;
obj.do_check_combo_arg_ = true;
Expand Down Expand Up @@ -129,6 +134,24 @@
function kt = get.KeyType(obj)
kt = obj.key_type_;
end
function obj = set.KeyType(obj,type)
if isnumeric(type)
type = class(type);
end
switch(type)
case('uint32')
obj.key_conv_handle_ = @uint32;
case('uint64')
obj.key_conv_handle_ = @uint64;
case('double')
obj.key_conv_handle_ = @double;
otherwise
error('HORACE:fast_map:invalid_argument', ...
'Type %s as fast map key is not yet supported',type)
end
obj.keys_ = obj.key_conv_handle_(obj.keys_);
obj.key_type_ = type;
end
%
function ks = get.keys(obj)
ks = obj.keys_;
Expand Down Expand Up @@ -326,13 +349,13 @@
% and .sqw data format. Each new version would presumably read
% the older version, so version substitution is based on this
% number
ver = 1;
ver = 2;
end

function flds = saveableFields(~)
% get independent fields, which fully define the state of the
% serializable object.
flds = {'keys','values'};
flds = {'keys','values','KeyType'};
end
%
function obj = check_combo_arg(obj)
Expand All @@ -343,6 +366,13 @@
obj = check_combo_arg_(obj);
end
end
methods(Access=protected)
function [S,obj] = convert_old_struct (obj, S, ver)
if ver == 1
S.KeyType = 'uint32';
end
end
end

methods(Static)
function obj = loadobj(S)
Expand Down

0 comments on commit 01a28b1

Please sign in to comment.