yorhaha

yorhaha

深層学習の最適化器

オプティマイザー#

バックプロパゲーション#

バックプロパゲーションの主なアイデアは、モデルパラメータの勾配を損失関数を計算し、勾配降下法を使用してパラメータを更新し、損失関数を最小化し、モデルを最適解に近づけることです。

実装プロセスには、2 つの主要なステップが含まれます:フォワードプロパゲーションとバックプロパゲーション。フォワードプロパゲーションは、ネットワークの入力が各層を通過した出力を計算します。バックプロパゲーションは、チェーンルールを使用して損失関数の各パラメータの勾配をネットワークに戻し、パラメータを更新します。この反復プロセスにより、ニューラルネットワークは徐々に入力と出力のマッピング関係を学習し、ネットワークの予測性能を向上させることができます。

バックプロパゲーションの基本的なアイデアは、損失関数をネットワーク内で逆方向に伝播させ、出力層から入力層に向かって勾配を計算し、累積します。具体的には、アルゴリズムはネットワークの最後の層から始めて、出力層の誤差勾配を計算し、その勾配を前の層に伝播させ、反復的に計算を行い、入力層に伝播させます。各層では、チェーンルールに基づいて、現在の層の勾配にその層の重みを乗算し、前の層に伝播させます。

バックプロパゲーションの手順は次のとおりです:

  1. フォワードプロパゲーション:入力データをネットワークに送り、出力結果を計算します。
  2. 損失の計算:出力結果を真のラベルと比較し、損失関数の値を計算します。
  3. バックプロパゲーション:出力層から始めて、損失関数に基づいて出力層の勾配を計算し、それから各層の勾配を前方に伝播させます。
  4. パラメータの更新:勾配降下法や他の最適化アルゴリズムを使用して、計算された勾配に基づいてネットワークのパラメータを更新します。
  5. ステップ 1 から 4 を繰り返し、停止条件(最大反復回数や損失関数の収束など)に達するまで続けます。

バッチ勾配降下法(GD)#

θt=θt1αθ1Ni=1NJ(θ;x(i))\boldsymbol{\theta}_{\boldsymbol{t}}=\boldsymbol{\theta}_{\boldsymbol{t}-\mathbf{1}}-\alpha \nabla_{\boldsymbol{\theta}} \frac{1}{N} \sum_{i=1}^{N} J\left(\boldsymbol{\theta} ; x^{(i)}\right)

NN:すべてのサンプル

確率的勾配降下法(SGD)#

θt=θt1αθ1BSi=1BSJ(θ;x(i))\boldsymbol{\theta}_{\boldsymbol{t}}=\boldsymbol{\theta}_{\boldsymbol{t}-\mathbf{1}}-\alpha \nabla_{\boldsymbol{\theta}} \frac{1}{\text{BS}} \sum_{i=1}^{\text{BS}} J\left(\boldsymbol{\theta} ; x^{(i)}\right)

BS\text{BS}:ミニバッチ

Adagrad#

vt=vt1+gt2θt=θt1αgtvt+ϵv_t=v_{t-1}+g_t^2\quad \theta_t=\theta_{t-1}-\alpha\frac{g_t}{\sqrt{v_t+\epsilon}}

RMSProp#

vt=γvt1+(1γ)gt2θt=θt1αgtvt+ϵv_t=\gamma v_{t-1}+(1-\gamma)g_t^2 \quad \theta_t=\theta_{t-1}-\alpha\frac{g_t}{\sqrt{v_t+\epsilon}}

Adam#

mt=β1mt1+(1β1)gtvt=β2vt1+(1β2)gt2θt=θt1αmtvt+ϵm_t=\beta_1 m_{t-1} + (1-\beta_1)g_t\qquad v_t=\beta_2 v_{t-1}+(1-\beta_2)g_t^2\qquad \theta_t=\theta_{t-1}-\alpha\frac{m_t}{\sqrt{v_t+\epsilon}}

ϵ=109,β1=0.9,β2=0.999\epsilon=10^{-9},\beta_1=0.9, \beta_2=0.999

ウォームアップ:

m^t=mt/(1β1t)v^t=vt/(1β2t)\hat{m}_t=m_t/(1-\beta_1^t)\qquad \hat{v}_t=v_t/(1-\beta_2^t)
読み込み中...
文章は、創作者によって署名され、ブロックチェーンに安全に保存されています。