Skip to content

Commit

Permalink
randomized dtw algorithm
Browse files Browse the repository at this point in the history
derohde committed Nov 17, 2021
1 parent 368e191 commit 9021bd0
Showing 3 changed files with 74 additions and 0 deletions.
4 changes: 4 additions & 0 deletions include/dynamic_time_warping.hpp
Original file line number Diff line number Diff line change
@@ -10,10 +10,13 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI

#pragma once

#include <map>

#include "types.hpp"
#include "point.hpp"
#include "interval.hpp"
#include "curve.hpp"
#include "random.hpp"

namespace Dynamic_Time_Warping {

@@ -27,6 +30,7 @@ namespace Discrete {
};

Distance distance(const Curve&, const Curve&);
Distance distance_randomized(const Curve&, const Curve&);
}

}
69 changes: 69 additions & 0 deletions src/dynamic_time_warping.cpp
Original file line number Diff line number Diff line change
@@ -53,6 +53,75 @@ Distance distance(const Curve &curve1, const Curve &curve2) {
return result;
}

Distance distance_randomized(const Curve &curve1, const Curve &curve2) {
Distance result;
const auto start = std::clock();

std::priority_queue<std::pair<distance_t, std::pair<curve_size_t, curve_size_t>>> queue;
std::map<std::pair<curve_size_t, curve_size_t>, bool> seen;
distance_t cost = std::numeric_limits<distance_t>::infinity();

Point min_coords(curve1.dimensions()), max_coords(curve1.dimensions());
for (curve_size_t i = 0; i < curve1.complexity(); ++i) {
for (dimensions_t j = 0; j < curve1.dimensions(); ++j) {
min_coords[j] = std::min(min_coords[j], curve1[i][j]);
max_coords[j] = std::max(max_coords[j], curve1[i][j]);
}
}
for (curve_size_t i = 0; i < curve2.complexity(); ++i) {
for (dimensions_t j = 0; j < curve1.dimensions(); ++j) {
min_coords[j] = std::min(min_coords[j], curve2[i][j]);
max_coords[j] = std::max(max_coords[j], curve2[i][j]);
}
}

const distance_t w = min_coords.dist(max_coords);
distance_t d;

auto ugen = Random::Uniform_Random_Generator<>();

queue.emplace(-curve1[0].dist(curve2[0]), std::make_pair(0, 0));

while (not queue.empty()) {
auto current = queue.top();
queue.pop();

auto i = current.second.first;
auto j = current.second.second;
auto curr_cost = -current.first;

if (i == curve1.complexity() - 1 and j == curve2.complexity() - 1) {
cost = std::min(cost, curr_cost);
}

if (not seen[current.second]) {
seen[current.second] = true;
if (i < curve1.complexity() - 1) {
if (j < curve2.complexity() - 1) {
d = curve1[i+1].dist(curve2[j+1]);
if (ugen.get() > d / w)
queue.emplace(-(curr_cost + d), std::make_pair(i+1, j+1));

d = curve1[i].dist(curve2[j+1]);
if (ugen.get() > d / w)
queue.emplace(-(curr_cost + d), std::make_pair(i, j+1));
}
d = curve1[i+1].dist(curve2[j]);
if (ugen.get() > d / w)
queue.emplace(-(curr_cost + d), std::make_pair(i+1, j));
} else if (j < curve2.complexity() - 1) {
d = curve1[i].dist(curve2[j+1]);
if (ugen.get() > d / w)
queue.emplace(-(curr_cost + d), std::make_pair(i, j+1));
}
}
}
const auto end = std::clock();
result.time = (end - start) / CLOCKS_PER_SEC;
result.value = cost;
return result;
}

} // end namespace Discrete

} // end namespace Dynamic Time Warping
1 change: 1 addition & 0 deletions src/fred_python_wrapper.cpp
Original file line number Diff line number Diff line change
@@ -162,6 +162,7 @@ PYBIND11_MODULE(backend, m) {
m.def("continuous_frechet", &fc::distance);
m.def("discrete_frechet", &fd::distance);
m.def("discrete_dynamic_time_warping", &ddtw::distance);
m.def("discrete_dynamic_time_warping_randomized", &ddtw::distance_randomized);

m.def("minimum_error_simplification", &minimum_error_simplification);
m.def("approximate_minimum_link_simplification", &approximate_minimum_link_simplification);

0 comments on commit 9021bd0

Please sign in to comment.