MATHGRAM

主に数学とプログラミング、時々趣味について。

RustでForward自動微分を実装してみた

Rustの理解がまだまだすぎて. モジュールとして全く使えない実装になっているので悪しからず. Forward自動微分そのものを理解することを目標とします.

ちなみに僕が参考にしたのは以下のサイトです.

ありがとうございました.

自動微分とは

自動微分そのものの説明は上にあげたサイトがとても詳しく説明しています. 数値微分を理解している人ならば, 簡単に違いを理解できると思います.

自分なりに自動微分を簡単にまとめると,
導関数を先に定義して計算効率&精度を上げようって感じです.
例えば,  \sin{x}微分 \cos{x}である, と先に定義しておくわけですね.

参考にしたサイトと同じようなことを説明しても意味はないと思うので, ここでは例を複数あげることで理解の助けになれればと思います.

とりあえず, Forward自動微分の実装で重要なのは, 以下の3点だと自分は理解しました.

  • 二重数と呼ばれる, Dual型の定義
  • Dual型に対する基本演算の定義
  • Dual型の初期値の設定

1つずつ詳しく見ていくことで, どの言語でも実装できるような説明を心がけてみます.

Dual型の定義

Dual型はある変数xとその微小量を表すdxをもつ型です.
Rustでは以下のように実装しています.

#[derive(Debug, Copy, Clone)]
pub struct Dual {
    var: f32,
    eps: f32,
}

みてわかる通り, 普通の変数に微小量がくっついただけなので, int型やfloat型の拡張と考えていいと思います.

基本演算の定義

では, 次にDual型に対する基本演算を定義します.
というより, 微小量 dxに関する基本演算を定義します.

例として, まずはを定義してみましょう.

こんな式があったとします.

 y = x + \sin{x}

突然ですが, この y xに関して微分して見てください.

・・・はい. 恐らくみなさんの頭の中では第1項, 第2項をそれぞれ微分して最後に和をとる, というような暗算をしただろうと思います.

より詳しくいうのであれば, まず第1項を計算する.

\displaystyle \frac{d}{dx} x = 1

次に第2項も微分する.

\displaystyle \frac{d}{dx} \sin{x} = \cos{x}

そして最後に和をとる.

\displaystyle \frac{d}{dx}y = 1 +  \cos{x}

以上のようなステップを踏むことで, y xに関する微分を導いたと思います.

ここで重要なのは最後に和をとったことです.

つまりyの微小量は, それぞれの項の微小量を単に足すことで求めることができます.

より簡潔に言うならば, 和の微分微分の和ということです.

以上のルールに則りRustでDual型の和を定義するとこのようになります.

// + 演算子のオーバーロード
impl Add for Dual {
    type Output = Dual;
    fn add(self, r: Dual) -> Dual {
        Dual {
            var: self.var + r.var,
            //ある変数の和に対する微小量の和を定義する.
            eps: self.eps + r.eps
        }
    }
}

それでは次にを定義をしてみましょう.
先ほどと同じように具体例を出してみます.

 y = x  \sin{x}

和の時と同様にx微分してみてください.

・・・はい. 今度は高校の時に呪文のように覚えた微分そのまま, そのまま微分のルールに基づいて暗算したのではないでしょうか. (この呪文は僕だけかもしれませんが…)

具体的に書いてみると,

 \displaystyle
\frac{d}{dx}y = \sin{x}\frac{d}{dx}x + x \frac{d}{dx}\sin{x}

こうですね. それではこのルールに則りDual型の積, もといDual型の微小量に対する積を定義しましょう.

// 積のオーバーロード
impl Mul for Dual {
    type Output = Dual;
    fn mul(self, r: Dual) -> Dual {
        Dual {
            var: self.var * r.var,
            // ある変数の積に対する微小量の積を定義する.
            eps: self.eps*r.var + self.var*r.eps
        }
    }
}

以上のように差や商に関してもDual型の演算を定義してあげることで, 勝手に微小量が計算されちゃうよっていう寸法です.

ここで簡単な式に実践してみましょう.

 y = x^2 + x

この式の微分

 \displaystyle
\frac{d}{dx}y = 2x + 1

ですね. 簡単です. この式によると x = 2の点での傾きは5です. さてDual型を使ってこの5は計算できるのか.

// 式の定義
fn example1(x: Dual) -> Dual {
    x*x + x
}

fn main(){
    // x = 2 なので varは 2 とします.
    let x = Dual{var: 2f32, eps: 1f32};

    println!("{:?}", example1(x));
}

出力

Dual { var: 6, eps: 5 }

おーちゃんと計算できてますね!

先ほど説明した四則演算のように,  \sin \expも定義してあげればこれらの演算が入っている式でも問題なく傾きを求めることができます.

こんな感じでざっと定義してあげて…

impl Dual {
    fn sin(self) -> Dual {
        Dual {
            var: self.var.sin(),
            eps: self.eps*self.var.cos()
        }
    }

    fn exp(self) -> Dual {
        Dual {
            var: self.var.exp(),
            eps: self.eps*self.var.exp()
        }
    }
}

こいつらを含んだ式を適当に作ってあげて…

 y = \sin{x} + xe^{x}
fn example2(x: Dual) -> Dual {
    x.sin() + x*x.exp()
}

 x = 0の傾きを求める!

 \displaystyle \frac{d}{dx} y(0) = \cos{0} + e^{0} + 0e^{0} = 2
fn main(){
    // x = 0
    let x = Dual{var: 0f32, eps: 1f32};
    
    println!("{:?}", example2(x));
}

出力じゃ!!

Dual { var: 0, eps: 2 }

おもしれェェェエエエ!

仕組みは単純なのに簡単に傾きが求められる!やばい!

初期値の設定

ここでより一般的な拡張を考えてみると,

おいおい, 多変数関数の時はどうすんだい?

ってなりますよね. 実はForward modeは多変数関数に弱いんです. でもDual型の初期値の設定によって一応求めることができます.

例えばこんな式を用意しましょう.

 \displaystyle z(x, y) = \sin{y} + xy + ye^{x}

そして z xによる偏微分を求めてみる. 手計算するとこうです.

 \displaystyle
\frac{\partial}{\partial x} z(x, y) = y + ye^{x}

 xに関する傾きを求めるには次のように初期値を設定してあげましょう.

let x = Dual{var: 0f32, eps: 1f32};
let y = Dual{var: 2f32, eps: 0f32}; // epsを0にする!

つまり y xに対する微小量は0と, 定義してあげます. 上のように定義すると出力は

Dual { var: 2.9092975, eps: 4 }

こうなります. 実際,

 \displaystyle
\frac{\partial}{\partial x} z(0, 2) = 2 + 2e^{0} = 4

なのであってますね!逆に y偏微分を求めるには,

let x = Dual{var: 0f32, eps: 0f32}; // epsを0にする!
let y = Dual{var: 2f32, eps: 1f32}; 

こうすればOKですね.

1度の演算で1つの変数に対する傾きしか求められないのが残念ですね.

全ての変数に対して1度の演算で勾配を求められるのがBackward modeなのですが, ちょっとまだ理解できてません. と言うか仕組みはわかるけど実装ができない・・・. もうちょっと勉強してみます.

一応全コード載せとく

use std::ops::{Add, Sub, Mul, Div};
use std::f32;

#[derive(Debug, Copy, Clone)]
struct Dual {
    var: f32,
    eps: f32,
}

// + 演算子のオーバーロード
impl Add for Dual {
    type Output = Dual;
    fn add(self, r: Dual) -> Dual {
        Dual {
            var: self.var + r.var,
            //ある変数の和に対する微小量の和を定義する.
            eps: self.eps + r.eps
        }
    }
}

impl Sub for Dual {
    type Output = Dual;
    fn sub(self, r: Dual) -> Dual {
        Dual {
            var: self.var - r.var,
            eps: self.eps - r.eps
        }
    }
}

// 積のオーバーロード
impl Mul for Dual {
    type Output = Dual;
    fn mul(self, r: Dual) -> Dual {
        Dual {
            var: self.var * r.var,
            // ある変数の積に対する微小量の積を定義する.
            eps: self.eps*r.var + self.var*r.eps
        }
    }
}

impl Div for Dual {
    type Output = Dual;
    fn div(self, r: Dual) -> Dual {
        Dual {
            var: self.var / r.var,
            eps: self.eps/r.var - r.eps*self.var/r.var/r.var
        }
    }
}

impl Dual {
    fn sin(self) -> Dual {
        Dual {
            var: self.var.sin(),
            eps: self.eps*self.var.cos()
        }
    }

    fn cos(self) -> Dual {
        Dual {
            var: self.var.cos(),
            eps: -self.eps*self.var.sin()
        }
    }

    fn tan(self) -> Dual {
        Dual {
            var: self.var.tan(),
            eps: self.eps/(self.var.cos()*self.var.cos())
        }
    }

    fn exp(self) -> Dual {
        Dual {
            var: self.var.exp(),
            eps: self.eps*self.var.exp()
        }
    }
}

fn newton_sqrt(var: Dual) -> Dual {
    let mut y = Dual{var:2f32, eps:0f32};
    let two = Dual{var:2f32, eps:0f32};
    for i in 0..10 {
        y = (y + var/y) / two;
        println!("{:?}", y);
    }
    y
}

// 式の定義
fn example1(x: Dual) -> Dual {
    x*x + x
}

fn example2(x: Dual) -> Dual {
    x.sin() + x*x.exp()
}

fn example3(x: Dual, y: Dual) -> Dual {
    y.sin() + x*y + y*x.exp()
}

fn main(){
    // x = 0 の傾きを求めてみる
    let x = Dual{var: 0f32, eps: 1f32};
    let y = Dual{var: 2f32, eps: 0f32};

    //println!("{:?}", example2(x));
    println!("{:?}", example3(x, y));
    //println!("{:?}", newton_sqrt(var));
}

まとめ

1年前に僕の尊敬する師匠から自動微分の話を聞いたのですが, 当時は全く意味がわかりませんでした. そもそもコンピュータがどのように微分しているかすら考えたことなかったですからね.

Backward modeも頑張って実装したいところ・・・.

以上です.