n-th Tensorの距離行列計算を高速でしたい
目次
クイックサマリ
- 手持ちのデータが例えば3rd Tensorだったときに、距離行列を計算する方法を紹介
- 最大で310倍はやくなった
- Numpy, Pytorchで試してみた
距離行列?そもそもなんで必要?
機械学習ではよくカーネル関数を使う場面に出くわします。古典的にはSVMとか。最近ではGP: Gaussian Processも流行りです。ぼくの同僚たちが研究してるのはGPです。ちょっとウラヤマシイ。もちろん、深層学習もカーネル関数を使うことが「よく」あります。
で、カーネル関数の多くは距離が式の中に入ってます。例えば、ガウスカーネルだと次の式になります。

sigma^2はバンド幅で || x -y ||² が距離ですね。
xとyがスカラ値なら || x – y ||^2 は簡単な話でしょう。でも、もしx, yがベクタだったら?x, yが行列だったら?どうします?
2重ループがもっとも簡単
2重ループで計算しちゃうのが直感的で簡単な実装です。例えば、X, Yを次のように定義しましょう。X, Yはm個のサンプルを持つ集合です。

距離行列Dはxとyのデータ間距離を意味します。つまり、Dはm行m列の行列です。

じゃあ、2重ループでDを求めるコードを書いてみましょう。
1 2 3 4 5 6 |
import numpy as np D = np.zeros((m, m)) for i, x in enumerate(X): for j, y in enumerate(Y): d = (x - y) ** 2 D[i, j] = d |
x, yがどんな形状(ベクトルだろうが、行列だろうが)であっても、基本的には2重ループで計算できます。
ボトルネック: Speed
2重ループは簡単ですが、計算速度が問題になります。そもそもpythonのループは遅い。
もし、x, yがベクトル、つまり、X, Yが行列の場合(x, yが1-tensorでX, Yが2nd-tensor)は次のようなコードで高速化できます。高速化というか、線形代数的にガウスカーネルの式を置き換えしただけです。
1 2 3 4 |
XY = np.dot(X, Y.T) X_sqnorms = np.diagonal(XX) Y_sqnorms = np.diagonal(YY) D_XY = -2 * XY + X_sqnorms[:, np.newaxis] + Y_sqnorms[np.newaxis, :] |
じゃあ、次にx, yが行列の場合を考えてみましょう。つまり、X: (100, 64, 12), Y: (100, 64, 12)です。XとYは100個のサンプルを持つ集合です。1サンプルは(64, 12)の大きさをした行列です。数学的に書くとこうなります。

こうなると、上のコードは使えません。なぜなら np.dot(X, Y.T)
っていう転置した二乗計算ができないからです。
で、探しに探すとStackoverflowにエレガントな解法がありました。この解法を使えば、基本的にどんなn-th tensorにでも線形な時間で計算可能です。
アイディアの核になるのは pdict()
関数です。これは距離を高速に計算してくれる実装ですが、入力はベクトルのみです。そこで、データの形を変形し、ベクトルにしてしまおうというわけです。ベクトルになったとしても、indexは一意に保たれますから、情報は可逆です。


速度比較
というわけで実装して比較してみました。次の3ケースを比較してます。
- with numpy, D(X, X) where X is (50, 64, 12).
- with pytorch, D(X, X) where X is (50, 64, 12).
- with pytorch, D(X, Y) where X, Y are (50, 64, 12).
コードはGistに載せておきました。どのケースにおいても2重for-loopとreshapeアプローチの解にほとんど差がないことを確認してます。
計算スピードを比較してみましょう。D(X, X)のPytorchケースでは実に350倍も早くなってます。
1 2 3 4 5 6 7 8 9 10 11 |
# With numpy D(X, X) euclidean_normal(x_array): 0.0548619776032865 euclidean_n_th_tensor(): 0.0010964786633849144 # With torch D(X, X) euclidean_normal_torch(x_array): 0.215167117677629 euclidean_n_th_tensor_torch(): 0.0006791732273995876 # With torch D(X, Y) euclidean_normal_torch_xy(x_torch, y_torch): 0.2175855808891356 euclidean_n_th_tensor_torch_xy(): 0.022157766856253147 |
おしまい
StackOverflowの人。ありがとう(^^ゞ
ディスカッション
コメント一覧
まだ、コメントがありません