チュートリアル:独自のロス関数を利用する

Tuesday, September 04, 2018

App , Cloud

Posted by Yoshiyuki Kobayashi

Neural Network Consoleには、SquaredError、BinaryCrossEntropy、CategoricalCrossEntropyなど、基本的なロス関数が予めレイヤーとして用意されています。
一方で、課題によっては独自のロス関数を用いた最適化が必要になることも多々あります。

本チュートリアルでは、Neural Network Console上で予め用意されていないロス関数を自分で定義し、学習に用いる方法について解説します。

 

1. ロスとして扱われる値はどれ?

ロス関数の定義方法について解説する前に、Neural Network Consoleにおけるロス関数の扱いについて復習します。

Neural Network Consoleは、詳細設定タブ、Optimizerのネットワークで指定されたネットワークにおいて、末端のレイヤーの出力値を各レイヤーで平均し、それらを合計した値を最小化すべきロスとします。

例えば、以下の複数の末端レイヤーを持つネットワークAと、それらの末端のレイヤーの平均とその合計を算出して1つにまとめたネットワークBは、学習において全く同じ挙動を示します。


ネットワークA


ネットワークB

以上の仕組みを利用し、ネットワークの末端にロスの値が現れるようにネットワークを設計することで、様々な独自のロス関数を扱うことができます。

 

2. Squared Error(二乗誤差)ロスを定義する

ここでは、仮にNeural Network ConsoleにSquaredErrorレイヤーが存在しなかった場合を想定し、SquaredErrorレイヤーを用いず自分で二乗誤差ロスを定義してみることにします。題材としては、06_auto_encoderのサンプルプロジェクトを用います。このサンプルプロジェクトは、入力画像を低次元に圧縮した後元の画像を再構成するオートエンコーダーを学習するもので、再構成画像と入力画像との二乗誤差を最小化するためにSquaredErrorレイヤーが用いられています。


06_auto_encoderサンプルプロジェクト

さて、二乗誤差ロスは以下の式で表されるように、2つの入力について要素毎に二乗を計算するものです。

独自のロス関数を定義するには、この式を元にMathレイヤー、Arithmetic(算術演算)レイヤーなどを用いてロス関数を記述していきます。06_auto_encoderにおいては、x(0)は再構成画像、x(1)は入力画像になりますので、Squared Errorは以下のように表現することができます。


06_auto_encoderサンプルのSquaredErrorロスを独自に定義した例

ここで、Sub2レイヤーはx(0)とx(1)の要素毎の差分を、PowScalarレイヤーは要素毎の二乗を計算するために用いられています。

このように、Neural Network ConsoleではMathレイヤー、Arithmetic(算術演算)レイヤーなどを用いてロス関数の数式を記述していくことで、予め用意されている以外にも様々なロス関数を扱うことができます。

 

3. ロス関数を独自に定義した際の注意点

ロス関数を独自に定義した場合、推論用のネットワークを手動で定義しなければならない場合があります。

Neural Network Consoleは、Mainネットワーク(名前がMainのネットワーク)で定義したネットワーク構造を元に、学習中の評価用ネットワーク(MainValidation)、推論用ネットワーク(MainRuntime)を自動作成します。この際、推論用のMainRuntimeネットワークは、Mainネットワークに含まれるあらかじめ用意されたロス関数の手前の値を最終的な出力とするように構成されます。このため、独自にロス関数を定義し、Mainネットワークにあらかじめ用意されたロス関数が全く含まれない場合、MainRuntimeネットワークは何も出力しないネットワークになってしまいます。

このような場合は、推論用のネットワークを手動で設計、指定することで、正しく推論を実行できるようになります。推論用のネットワークを手動で設計、指定する手順は以下の通りです。

 

  1. 編集タブ、Mainネットワークタブ右の+ボタンをクリックして新しいネットワークを追加
  2. ネットワーク名(Network_2)をクリックしてRuntimeなどにリネーム
  3. Mainネットワークの内容を全選択(Ctrl+A)してコピー(Ctrl+C)
  4. RuntimeネットワークにコピーしたMainネットワークを貼り付け(Ctrl+V)
  5. Runtimeネットワークで、ロス関数に相当する箇所を削除
  6. もしBatchNormalizationを利用している場合、Runtimeネットワークの全てのBatchNormalizationのBatchStatプロパティをFalseに設定
  7. Runtimeネットワークの最後(上記06_auto_encoderの場合はSigmoid_2)にIdentityレイヤーを接続し、Nameプロパティを推論時の出力変数名(任意)にリネーム
  8. 詳細設定タブ、Executorのネットワークを「MainRuntime」から「Runtime」に変更