【DTW】ソースコード
どうも、姉から怒られて現在の身辺状況や進捗を身内に報告するために実家へ帰ったら、実家でも怒られたり色々と細かい予定が決まってきて、てんやわんやしている最中です。
今日はデザインパターンの続きを書こうと思ったのですが、片方が風邪を引いてしまったので一日中家に居ることにしたので、大学に今学んでいる技術書を全部持って行ってしまっていたので、ネタに困ったので今回は簡単に書けるDTWについてのソースコードを見て行きましょう。
DTWとは?
Dynamic Time Warping の通称。
主に音声認識やパターン認識に使われる計算法でDPの一種である。
ここらはいつもどおりWikipedia大先生に丸投げ 【音声認識 - Wikipedia】
リンク先にDTWについてのリンクがあるけど英語なので省略。
ソースコード
小一時間程度で実装したのでバグがあるかもしれませんが、ソースはこんな感じです。
#include <cstdlib> #include <cmath> #include <vector> #include <queue> //! @note 今回は実装の手間を省くためにvectorを使っているけど、配列を使って配列のサイズを別な変数でやったほうが速度や分量的にはスマート class DTW { private: static const int INF = static_cast<int>(1e+6); std::vector<int> ref; std::vector<int> in; std::vector<std::vector<int> > D; /// @brief ここでは局所距離を音声の周波数の差(スカラ)とする int calcLocalDistance(int t, int tau) { return static_cast<int>(fabs(static_cast<double>(in.at(t) - ref.at(tau)))); } /// @brief 配列の範囲内か調べる bool isInRange(int t, int tau) { return 0 <= t && in.size() > t && 0 <= tau && ref.size() > tau; } public: /// @brief 参照・入力パターン、積み上げ値の初期化 DTW(std::vector<int> _ref, std::vector<int> _in):ref(_ref), in(_in) { D = std::vector<std::vector<int> >(ref.size(), std::vector<int>(in.size(), INF)); for(int t = 0; t < in.size(); ++t) { D[0][t] = 0; // 積み上げ値を初期条件に従い初期化 } } /// @brief 入力と参照パターンから積み上げ計算を行う void accumulate() { for(int t = 0; t < in.size(); ++t) { for(int tau = 0; tau < ref.size(); ++tau) { int nt, ntau; int d1, d2, d3; /// @note 配列に変量を入れてforで回してもいい // path1(時間方向に拡大) nt = t + 1; ntau = tau; d1 = (isInRange(nt, ntau)) ? D[tau][t] + calcLocalDistance(nt, ntau) : INF; // path2(参照、時間方向に拡大) nt = t + 1; ntau = tau + 1; d2 = (isInRange(nt, ntau)) ? D[tau][t] + calcLocalDistance(nt, ntau) : INF; // path3(参照方向に拡大) nt = t; ntau = tau + 1; d3 = (isInRange(nt, ntau)) ? D[tau][t] + calcLocalDistance(nt, ntau) : INF; // 最小のpathを採用 D[ntau][nt] = std::min(d1, std::min(d2, d3)); } } } /// @brief 積み上げ値でしきい値以下の最小の値をスポッティング認識する int spotting(int threshold) { const int TAU = ref.size() - 1; int tau = TAU; int t; int minD = INF; // 最小の積み上げ局所距離を探索 for(int tt = 0; tt < in.size(); ++tt) { if(minD > D[TAU][tt]) { minD = D[TAU][tt]; t = tt; } } return (minD < threshold) ? t : -1; } /// @brief 積み上げ値から逆演算を行うことでマッチングするパターンを見つけ出す std::queue<int> backtrack(int spottingTime) { std::queue<int> result; const int TAU = ref.size() - 1; int tau = TAU; int t = spottingTime; // 最初の値も結果なので保存 result.push(t); // 参照パターンの全フレームで逆演算を行う while(tau > 0) { int minD = INF; int tt = t, ttau = tau; int pt, ptau; int d1, d2, d3; // path1(時間方向に後退) pt = tt - 1; ptau = tau; if(isInRange(pt, ptau) && minD > D[ptau][pt]) { minD = D[ptau][pt]; t = pt; tau = ptau; } // path2(参照、時間方向に後退) pt = tt - 1; ptau = ttau - 1; if(isInRange(pt, ptau) && minD > D[ptau][pt]) { minD = D[ptau][pt]; t = pt; tau = ptau; } // path3(参照方向に後退) pt = tt; ptau = ttau - 1; if(isInRange(pt, ptau) && minD > D[ptau][pt]) { minD = D[ptau][pt]; t = pt; tau = ptau; } // 最終的に選択された時間を結果に格納 result.push(t); } return result; } }; int main(int argc, char** argv) { // 参照・入力パターン(ここでは音声)を読み込み、配列に離散化した連続な値として返す // *ここは本質ではないので実装を省略 std::vector<int> ref = loadSoundData(argv[1]); std::vector<int> in = loadSoundData(argv[2]); int threshold = atoi(argv[3]); // しきい値 DTW dtw(ref, in); dtw.accumulate(); int spottingTime = dtw.spotting(threshold); if(spottingTime == -1) exit(1); std::queue<int> result = dtw.backtrack(spottingTime); return 0; }