Skip to content

Commit

Permalink
Merge pull request #12 from shapelets/feature/findBestNOccurrences
Browse files Browse the repository at this point in the history
Add mass and findBestNOccurrences functions
  • Loading branch information
jrecuerda authored Jun 10, 2019
2 parents e6610fd + f9811b8 commit a2f2bb4
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 2 deletions.
70 changes: 70 additions & 0 deletions +khiva/Matrix.m
Original file line number Diff line number Diff line change
Expand Up @@ -167,5 +167,75 @@
profile = khiva.Array(profileRef);
index = khiva.Array(indexRef);
end

function distances = mass(query, tss)
    % Mueen's Algorithm for Similarity Search (MASS).
    %
    % Layout of the returned array:
    %   - 1st dimension: index of the subsequence in the time series.
    %   - 2nd dimension: query.
    %   - 3rd dimension: time series.
    %
    % Using MATLAB's one-based subscripts, the element at position
    % (2, 3, 4) is the distance of the third query to the fourth time
    % series for the second subsequence of that time series.
    %
    % [1] Yan Zhu, Zachary Zimmerman, Nader Shakibay Senobari,
    % Chin-Chia Michael Yeh, Gareth Funning, Abdullah Mueen,
    % Philip Brisk and Eamonn Keogh (2016). Matrix Profile II:
    % Exploiting a Novel Algorithm and GPUs to break
    % the one Hundred Million Barrier for Time Series Motifs and Joins.
    % IEEE ICDM 2016.
    %
    % *query* KHIVA Array whose first dimension is the length of the
    % query time series and whose second dimension is the number of
    % queries.
    %
    % *tss* KHIVA Array whose first dimension is the length of the
    % time series and whose second dimension is the number of time
    % series.
    %
    % *distances* KHIVA Array with the distances.

    % Output slot for the native call; calllib hands the updated
    % pointer back as its third output.
    resultRef = libpointer('voidPtrPtr');
    queryRef = query.getReference();
    seriesRef = tss.getReference();
    [~, ~, resultRef] = calllib('libkhivac', 'mass', ...
        queryRef, seriesRef, resultRef);

    % Wrap the returned reference in a KHIVA Array for the caller.
    distances = khiva.Array(resultRef);
end

function [distances, indexes] = findBestNOccurrences(query, tss, n)
    % Calculates the N best matches of several queries in several
    % time series.
    %
    % Layout of both returned arrays:
    %   - 1st dimension: rank of the match (nth best match).
    %   - 2nd dimension: query.
    %   - 3rd dimension: time series.
    %
    % Using MATLAB's one-based subscripts, the distance at position
    % (2, 3, 4) is the second best distance of the third query in the
    % fourth time series, and the index at position (2, 3, 4) is the
    % index of the subsequence that yields that distance.
    %
    % *query* KHIVA Array whose first dimension is the length of the
    % query time series and whose second dimension is the number of
    % queries.
    %
    % *tss* KHIVA Array whose first dimension is the length of the
    % time series and whose second dimension is the number of time
    % series.
    %
    % *n* Number of matches to return.
    %
    % *distances* KHIVA Array with the distances.
    %
    % *indexes* KHIVA Array with the indexes.
    % NOTE(review): the index values look zero-based (the unit test
    % for this function expects 7 for a match that starts at the
    % eighth sample) - confirm against the native library.

    % Output slots for the native call; calllib hands the updated
    % pointers back as its fourth and fifth outputs.
    distOut = libpointer('voidPtrPtr');
    idxOut = libpointer('voidPtrPtr');
    queryRef = query.getReference();
    seriesRef = tss.getReference();
    [~, ~, ~, distOut, idxOut] = calllib('libkhivac', ...
        'find_best_n_occurrences', queryRef, seriesRef, n, ...
        distOut, idxOut);

    % Wrap the returned references in KHIVA Arrays for the caller.
    distances = khiva.Array(distOut);
    indexes = khiva.Array(idxOut);
end
end
end
55 changes: 53 additions & 2 deletions tests/MatrixUnitTests.m
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ function testStomp(testCase)
profileHost = profile.getData();
indexHost = index.getData();
diffProfile = abs(profileHost - expectedProfile);
testCase.verifyLessThanOrEqual(diffProfile, 1e-2);
testCase.verifyLessThanOrEqual(diffProfile, 2e-2);
testCase.verifyEqual(indexHost, expectedIndex);
end

Expand All @@ -168,8 +168,59 @@ function testStompSelfJoin(testCase)
profileHost = profile.getData();
indexHost = index.getData();
diffProfile = abs(profileHost - expectedProfile);
testCase.verifyLessThanOrEqual(diffProfile, 1e-2);
testCase.verifyLessThanOrEqual(diffProfile, 2e-2);
testCase.verifyEqual(indexHost, expectedIndex);
end

function testMass(testCase)
    % Checks khiva.Matrix.mass for a single query against a single
    % time series.
    %
    % The 3-sample query is matched against a 14-sample series, and
    % the 12 resulting distances are compared element-wise with
    % reference values using an absolute tolerance.
    tolerance = 1e-3;

    queryArr = khiva.Array(single([4, 3, 8]'));
    seriesArr = khiva.Array(single([10, 10, 10, 11, 12, 11, 10, 10, 11, 12, 11, 14, 10, 10]'));

    resultArr = khiva.Matrix.mass(queryArr, seriesArr);
    hostDistances = resultArr.getData();

    % Reference distances, one per subsequence (column vector).
    reference = [1.732051; 0.328954; 1.210135; 3.150851; 3.245858; 2.822044; ...
        0.328954; 1.210135; 3.150851; 0.248097; 3.30187; 2.82205];

    absError = abs(hostDistances - reference);
    testCase.verifyLessThanOrEqual(absError, tolerance);
end

function testMassMultiple(testCase)
    % Checks khiva.Matrix.mass with two queries and two time series.
    %
    % Two 4-sample queries are matched against two 7-sample series;
    % the result is compared element-wise with a 4x2x2 array of
    % reference distances using an absolute tolerance.
    tolerance = 1e-4;

    queryArr = khiva.Array(single([[10, 10, 11, 11]', [10, 11, 10, 10]']));
    seriesArr = khiva.Array(single([[10, 10, 10, 11, 12, 11, 10]', ...
        [10, 11, 12, 11, 14, 10, 10]']));

    resultArr = khiva.Matrix.mass(queryArr, seriesArr);
    hostDistances = resultArr.getData();

    % Each page is a 4x2 matrix; the rows of the literals below become
    % its columns after the transpose.
    pageOne = [1.83880341, 0.87391543, 1.5307337, 3.69551826; ...
        3.26598597, 3.48967957, 2.82842779, 1.21162188]';
    pageTwo = [1.5307337, 2.17577887, 2.57832384, 3.75498915; ...
        2.82842779, 2.82842731, 3.21592307, 0.50202721]';
    % Stack the two pages along the third dimension -> 4x2x2.
    reference = cat(3, pageOne, pageTwo);

    absError = abs(hostDistances - reference);
    testCase.verifyLessThanOrEqual(absError, tolerance);
end

function testFindBestNOccurrences(testCase)
    % Checks khiva.Matrix.findBestNOccurrences for one query and the
    % single best match (n = 1).
    %
    % The same 14-sample series is used for both columns of the input.
    % Only the first element of each result is inspected: the best
    % distance must be at most 1e-2 and the best index must equal 7.
    column = [10, 10, 11, 11, 12, 11, 10, 10, 11, 12, 11, 10, 10, 11]';

    queryArr = khiva.Array(single([10, 11, 12]'));
    seriesArr = khiva.Array(single([column, column]));

    [distArr, idxArr] = khiva.Matrix.findBestNOccurrences(queryArr, seriesArr, 1);
    hostDistances = distArr.getData();
    hostIndexes = idxArr.getData();

    bestDistance = hostDistances(1);
    bestIndex = hostIndexes(1);
    testCase.verifyLessThanOrEqual(bestDistance, 1e-2);
    testCase.verifyLessThanOrEqual(abs(bestIndex - 7), 1e-4);
end

function testFindBestNOccurrencesMultipleQueries(testCase)
    % Checks khiva.Matrix.findBestNOccurrences with two queries, two
    % time series and the four best matches (n = 4).
    %
    % Two spot checks are made on the results: the distance at
    % position (3, 1, 2) and the index at position (4, 2, 1).
    tolerance = 1e-4;
    expectedDistance = 1.83880329;
    expectedIndex = 2;

    queryArr = khiva.Array(single([[11, 11, 10, 11]', [10, 11, 11, 12]']));
    seriesArr = khiva.Array(single([[10, 10, 11, 11, 10, 11, 10, 10, 11, 11, 10, 11, 10, 10]', ...
        [11, 10, 10, 11, 10, 11, 11, 10, 11, 11, 14, 10, 11, 10]']));

    [distArr, idxArr] = khiva.Matrix.findBestNOccurrences(queryArr, seriesArr, 4);
    hostDistances = distArr.getData();
    hostIndexes = idxArr.getData();

    distanceError = abs(hostDistances(3, 1, 2) - expectedDistance);
    testCase.verifyLessThanOrEqual(distanceError, tolerance);

    indexError = abs(hostIndexes(4, 2, 1) - expectedIndex);
    testCase.verifyLessThanOrEqual(indexError, tolerance);
end
end
end

0 comments on commit a2f2bb4

Please sign in to comment.