\newcommand{\R}{\mathbb{R}} \newcommand{\costR}{\mathcal{R}} AD をさらに詳しく掘り下げるために、n 個の入力から一つの出力を伴う関数 f、即ち f:\R^n\rightarrow \R のコンピュータへの実装を考えましょう。AD はベクトル型の関数にも適用できますが、簡単のため以下では、実数値をとる関数の場合のみを考えることとします。
AD はフォワードおよびリバースの2つのモードを持ちます。
f に対するフォワード(あるいは tangent-linear)AD バージョンは、入力 \mathbf{x},\mathbf{x^{(1)}}\in\R^n に対し、以下で与えられる関数 F^{(1)}:\R^{2n}\rightarrow \R です。
\begin{equation*} y^{(1)} = F^{(1)}(\mathbf{x},\mathbf{x^{(1)}}) = \nabla f(\mathbf{x}) \cdot \mathbf{x^{(1)}} = \left(\frac{\partial f}{\partial \mathbf{x}} \right) \cdot \mathbf{x^{(1)}} \end{equation*}
ここで、ドットは通常の内積を示します。f の全グラジェントを得るために、\mathbf{x^{(1)}} をカルテシアン座標基底ベクトルの全範囲に対応させ、F^{(1)} を繰り返し呼び出します。
- F^{(1)} の計算時間は F の計算時間のほぼ同程度である。
- 全グラジェント計算は f の計算時間のほぼ n 倍である。
- フォワードモード AD は有限差分と同程度のコストが掛かるが、マシン精度でグラジェントを計算する。
フォワードモード AD は通常、正確には微分される関数にも依りますが、n が 30 以下の小さい時に用いられます。
洞察
アジョイントモード AD を理解するために、入力 \mathbf{x}\in\R^n を、出力結果 y\in\R に至る、関数呼び出し列と並べて考えてみます(ここで \mathbf{x_m} は上で述べたカルテシアン座標基底ベクトルとは異なります)。
\mathbf{x} \overset{f_1}{\longrightarrow} \mathbf{x_1} \overset{f_2}{\longrightarrow} \mathbf{x_2} \longrightarrow \cdots \longrightarrow \mathbf{x_m} \overset{f_{m+1}}{\longrightarrow} y
我々の望みはグラジェント \partial y/\partial \mathbf{x} で、それは以下のようにチェーンルール(合成関数の微分則)を用いて得られるものです。
\frac{\partial \mathbf{x_1}}{\partial \mathbf{x}} \frac{\partial \mathbf{x_2}}{\partial \mathbf{x_1}} \frac{\partial \mathbf{x_3}}{\partial \mathbf{x_2}} \cdots \frac{\partial \mathbf{x_{m}}}{\partial \mathbf{x_{m-1}}} \frac{\partial y}{\partial \mathbf{x_m}}
数学的には
数学的にはこれをどのように計算しても問題ありません。通常は左から右へ計算します。
\left .\left(\cdots \left( \frac{\partial \mathbf{x_1}}{\partial \mathbf{x}} \frac{\partial \mathbf{x_2}}{\partial \mathbf{x_1}} \right) \frac{\partial \mathbf{x_3}}{\partial \mathbf{x_2}} \right) \cdots \frac{\partial \mathbf{x_{m}}}{\partial \mathbf{x_{m-1}}} \right) \frac{\partial y}{\partial \mathbf{x_m}}
このやり方は、プログラムは最初に \mathbf{x_1}、次に\mathbf{x_2} といった具合に、プログラムでの演算順序に相当するもので自然な考え方です。しかしながら、一般に各ヤコビアン \partial \mathbf{x_{i+1}}/\partial \mathbf{x_i} は行列のため、このやり方は最後の行列・ベクトル積以外はすべて行列・行列積となります。
そこで右側から始めてみましょう。
\frac{\partial \mathbf{x_1}}{\partial \mathbf{x}} \left( \frac{\partial \mathbf{x_2}}{\partial \mathbf{x_1}} \left( \frac{\partial \mathbf{x_3}}{\partial \mathbf{x_2}} \cdots \left( \frac{\partial \mathbf{x_{m}}}{\partial \mathbf{x_{m-1}}} \frac{\partial y}{\partial \mathbf{x_m}}\right)\cdots\right) \right.
すると、全てが行列・ベクトル積となり圧倒的に高速ですが、実際にはプログラムを逆に実行しなくてはいけません。つまり:
- \partial y/\partial \mathbf{x_m} の計算に必要なデータは計算の最後にしかない。それには y と \mathbf{x_m} が必要で、そのためには \mathbf{x_{m-1}} が必要、それには \mathbf{x_{m-2}} が必要・・・と、遡ってデータが必要になる
- この問題を解く一つの方法は、関連する中間的な値を保存しつつ、一旦プログラムを普通に実行させることである。
- そして今度は逆向きに、保存された中間値を用いてヤコビアン \frac{\partial \mathbf{x_{i+1}}}{\partial \mathbf{x_i}} を構成して、行列・ベクトル積を実行する。
これが AD のアジョイントモードと呼ばれるグラジェントの計算方法です。
アジョイントモデル
f のアジョイントモデルは、\R^n\!\times\!\R^n\!\times\!\R から \R^n への写像である関数 \mathbf{x_{(1)}} = F_{(1)}(\mathbf{x},\mathbf{x_{(1)}},y_{(1)}) で、以下の式で与えられます。
\begin{equation*} \mathbf{x_{(1)}} = \mathbf{x_{(1)}} + \nabla f(\mathbf{x}) \cdot y_{(1)} \end{equation*}
- y_{(1)} はスカラー関数である。ここで y_{(1)}=1 かつ \mathbf{x_{(1)}}=0 としてアジョイントモデル F_{(1)} を”一回”呼び出せば、f の偏微分の全ベクトルが計算される。
- ヤコビアン \partial \mathbf{x_{i+1}}/\partial \mathbf{x_i} は露わに構成されることはなく、そのスパース性を避けることが出来る。
- 一般に F_{(1)} の計算は、f の計算に必要な浮動小数点演算数の 5 倍を超えないことが証明できる。
- このことが、アジョイントモデルは全グラジェント計算のコストに対して、f の計算コストの R 倍(ここで R は小さな定数)で収まる事実を与える。
- しかしながら、アジョイントモデルの実装には、その計算よりも大変なデータフローの逆問題を解かなくてはならない。
- R の典型的な値は、コードに依存するが、5 から 50 の間である。
アジョイントモデルに必要なメモリ
アジョイントモデルの計算には、データフローの逆問題を解く必要があり、実際に逆向きの実行が必要となります。多くの AD ツール(dco を含む)は、この問題に対して、プログラムを普通に実行しながら、中間データを計算して "tape" と呼ばれるデータ構造のメモリに保存していくというアプローチをとります。比較的単純なコードであれば、tape は数 GB 程度ですが、本格的な製品プログラムでは多くの場合、大容量メモリを搭載したマシンでさえ容量が不足してしまいます。
この問題に対処するために dco は、コード中の様々なポイントへ、簡単にチェックポイントを挿入できるように、柔軟性のあるインターフェイスが用意されています。コードが逆向きに実行される場合、最後のチェックポイントは復元されて、その計算セクションが "tape" され逆向きに実行されます。次に最後から二番目のチェックポイントが復元されて、その計算セクションが(前回の計算結果を用いて)逆向きに実行され、以下同様にプレイバックされていきます。この方法では演算はメモリにとって代わり、tape のサイズはほぼ任意に制御可能です。
これは製品プログラムのアジョイントモデルを完璧に実行するための本質的な機能です。このチェックポイント機能や、アジョイントモデルコードのメモリ負荷低減のための他のテクニックについて、さらに情報を望まれる方は、nAG までご連絡ください。