【PyTorch入門】モデル作成のための必須知識を徹底解説 -誤差逆伝播法をやっつける-

PyTorch

【PyTorch入門】テンソルとは?初心者が知るべき基本を解説ではテンソルの基本的な扱い方を学びました。

作成したモデルを学習する方法の基礎として誤差逆伝播法があります。ニューラルネットワークを学ぶ上で必須の基礎理論ですが、多くの方が理解に苦しむAI学習最初の壁です。

本記事では無理なく誤差逆伝播法を学べるように必要な知識を順番に解説しています。以前挫折したことがある人もぜひ本記事で再チャレンジしてみてください。

計算グラフ

PyTorchは、計算グラフを使用して、ニューラルネットワークを構築します。PyTorchは、計算グラフを自動的に構築し、ニューラルネットワークの演算を効率的に実行します。

PyTorchで簡単な計算グラフを作成する例です。入力をx、重みのパラメータをwとします。requires_grad=True はPyTorchがパラメータの勾配 (後述) を自動計算するための指示となります。

import torch

x = torch.tensor([2.0])
w = torch.tensor([0.5], requires_grad=True)
y = w * x
z = 2 * y + 3

この計算グラフは、入力変数xと重み変数wを掛け合わせたyを、2倍して3を足したzを計算するモデルを表しています。
ニューラルネットワークでは勾配というものを計算してモデルを学習します。その勾配の自動計算にPyTorchはグラフを必要としますが、グラフはPyTorchが自動で作成するので意識することはあまりないかも知れません。

勾配

勾配とは一言でいうと関数の傾きです。

例えば「y = ax^2 + 2」というモデル関数があるとします。y が出力、xが入力、aが重みのパラメータです。現状 a = 3 となっていて、「y = 3x^2 + 2」というモデルになっているとします。そして、入力が 2 の場合に 4 を出力することを求められているとしましょう。しかし現状のモデルだと x = 2 だと y = 14 です。 a = 3 という現状のパラメータの値は誤りといえます。

この誤りを訂正するために、ニューラルネットワークでは本来の正解と現状の値との差 (損失) を最小化するように訓練してパラメータを更新していきます。
上記の例だと「数学で簡単に解ける」と思うかも知れませんが、ニューラルネットワークのモデル関数は非常に複雑で多数のパラメータを持っているので、数学的に最適なパラメータを一発で求めることは事実上不可能です。そのため、勾配の計算を行い学習に利用します。

「y = ax^2 + 2」のパラメータ a の勾配は 2x です。 x = 2 のときの a の勾配は4となります。つまり、a が増えると y は増える、 a が減ると y も減るという関係になっています。
現状 y = 14 に対して求められている出力が 4 なので、a の値を小さくすると正しいモデルに近づきます。
このようにニューラルネットワークでは勾配を利用して少しずつ正しい出力をするモデルになるように学習していきます。

先ほどの計算グラフのパラメータ w の勾配を考えてみましょう。微分が分からないという方は簡単に読み流してください。

y = w * x

yノードの勾配を計算するためには、傾きを求めればいいので y を x で微分します。y = w * xの微分はdy/dx = wです。よって、yノードの勾配は w となります。

z = 2 * y + 3

zノードの勾配を計算するためには、zをyで微分する必要があります。z = 2 * y + 3の微分はdz/dy = 2です。よって、zノードの勾配は 2 となります。

以上から、z に対する w の勾配は次のようになります。

dz/dw = dz/dy * dy/dw
      = 2 * x
      = 2x

x = 2.0 なので、wの勾配は4になります。

損失関数

損失関数とは、機械学習モデルの性能を評価するために用いられる関数のことです。損失関数は、モデルの出力と正解の差を表す指標で、この値が小さいほどモデルの性能が良いと評価されます。
まわりくどい気がしますが、複雑で多量のパラメータを持つニューラルネットワークを学習するにはそのような評価関数を通してモデルを学習していくことになります。

例えば、回帰問題 (入力の値から出力の値を予測する問題) では、モデルが予測した値と正解の値の差を二乗しその平均を計算した、平均二乗誤差(Mean Squared Error, MSE)をよく使用します。
例として構築した計算グラフで x = 2 、 w = 0.5 の場合、モデル出力は 5 になります [2 * (0.5 * 2) + 3] 。 正解の出力が 1 だとすると平均二乗誤差は (5 – 1)^2 で 16 になります。

以上の損失関数を考慮したグラフを作成してみます。

import torch

x = torch.tensor([2.0])
w = torch.tensor([0.5], requires_grad=True)
y = w * x
z = 2 * y + 3

# 正解ラベルを定義する
target = torch.tensor([1.0])

# 損失関数を定義する(平均二乗誤差)
criterion = torch.nn.MSELoss()

# 平均二乗誤差を計算する
loss = criterion(z, target)  # tensor(16., grad_fn=<MseLossBackward0>)

先ほどz に対する w の勾配を求めましたが、損失関数に対する w の勾配を求めましょう。
loss = (z – target)^2 なので、dloss/dz = 2z – 2target です。 つまり

dloss/dw = dloss/dz * dz/dy * dy/dw
         = 2(2 - target) * 2 * x
         = 4x(2 - target)
         = 4*2(2 - 1)
         = 8

損失関数に対する w の勾配は 8 と求まりました。勾配がプラスなので、wの値を減らすと損失関数が減少することが分かります。損失関数が減少することでモデルが正しい結果を出力する方向に一歩近づきます。

モデルの学習を行うとは、このように損失誤差の値を小さくするようにモデルのパラメータを更新することといえます。

誤差逆伝播法とは

今まで見た例ではパラメータは w ひとつでしたが、2つ以上になるとどのようにモデルを学習させれば良いでしょうか。パラメータ a を追加したグラフを構築してみましょう。

import torch

x = torch.tensor([2.0])
w = torch.tensor([0.5], requires_grad=True)
a = torch.tensor([1.0], requires_grad=True)  # 新たなパラメータを追加
y = w * x + a  # ここでパラメータaを計算に含める
z = 2 * y + 3

target = torch.tensor([1.0])

criterion = torch.nn.MSELoss()

loss = criterion(z, target)  # tensor(36., grad_fn=<MseLossBackward0>)

パラメータ a を追加することで損失関数の値が 36 になっています。このモデルには2つのパラメータが存在するので、それぞれのパラメータをどのように変化させれば損失関数が減るかを考えます。
やることは w のみの時と変わりません。それぞれ損失関数との勾配を求めればパラメータを動かすべき方向がわかります。

dloss/dw = dloss/dz * dz/dy * dy/dw
         = 2(2 - target) * 2 * x
         = 4x(2 - target)
         = 4*2(2 - 1)
         = 8

dloss/da = dloss/dz * dz/dy * dy/da
         = 2(2 - target) * 2 * 1
         = 4(2 - target)
         = 4(2 - 1)
         = 4

それぞれのパラメータの勾配はプラスになるので、それぞれのパラメータを減らすことで損失関数も減少しこのモデルの学習は進みます。

さて、この勾配の計算、dloss/dwとdloss/daを1から真面目にすることはないということに気づかれましたでしょうか?

グラフの演算の一番最後の部分、dloss/dz部分をまず計算します。dloss/dwとdloss/daの両方でこの微分式が使われているので、一度計算した結果を流用することができます。
続けて、dz/dyを計算します。こちらも両者で使われているので一度計算した結果を使いまわせます。その後dy/dwとdy/daを計算すれば勾配の計算ができます。

以上のようにグラフの後ろの微分値ほど使い回しができるという性質から、勾配の計算は後ろから順番に微分していくと効率的です。これを逆伝播といいます。そして、このように逆伝播を使って損失関数の値を最小化するように学習することを誤差逆伝播法といいます
反対に損失関数 (loss) の値を求めるときのように、グラフの演算を順番に実行することを順伝播といいます。

誤差逆伝播法について理解できましたでしょうか?ニューラルネットワークを学習する上では避けては通れない理論となります。実際の勾配の計算はPyTorchが行うので、計算をできる必要はありませんが、考え方は定着させておくと良いでしょう。

タイトルとURLをコピーしました