【PyTorch入門】データセットの使い方の超基本 -DataLoaderの使い方、標準化とは?-

PyTorch

画像データを入力して、その画像が「猫」であるかを判断するモデルを構築しても、学習させなければうまくモデルは判断することができません。なので、学習に使うための大量の画像が必要になります。

そこで、本記事では、CIFAR-10を使用してデータセットをモデルの学習に利用する方法を説明します。CIFAR-10は、画像認識タスクで広く使用されるデータセットの一つです。これをダウンロードして活用できるように順序立てて解説していきます。

CIFAR-10とは?

CIFAR-10は、32×32ピクセルのカラー画像から構成されるデータセットです。飛行機、自動車、鳥、猫、鹿、犬、カエル、馬、船、トラックの画像が含まれています。この画像を猫とそれ以外に分け、モデルを学習させます。

それではCIFAR-10データセットの中身を確認してみましょう。もしmatplotlibがインストールされていない場合は、事前に以下のコマンドを実行してください。

pip install matplotlib

ライブラリのインポート

まず、コードで必要となるライブラリをインポートします。

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split
import matplotlib.pyplot as plt
import numpy as np

torchvision: 画像処理用のライブラリ、データセットのロードや画像変換などに使う
transforms: 画像の前処理を行うためのツール
random_split: データセットの分割に使用(後述)
matplotlib.pyplot: 画像を表示するために使うライブラリ

CIFAR-10データのダウンロードと前処理

次にCIFAR-10の画像データをダウンロードし、機械学習モデルで使用するための前処理を行います。

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
loader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=0)

transform: transforms.Composeで実行するデータ処理内容を定義しています。transforms.ToTensor()は入力データ(ここでは画像データ)をPytorchで処理可能なテンソル形式に変換します。
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))は画像の各チャンネル(赤、緑、青)のピクセル値に対して、指定された平均(0.5)と標準偏差(0.5)を用いて標準化を行います。
これらの詳しい処理については後述しますが、難しければひとまず「画像のピクセルデータをディープラーニングで学習しやすい値に変形している」という理解で構いません。
trainset: CIFAR-10データセットをダウンロードし、transform=transformで、上記のtransformで定義したデータ変換処理を適用します。
loader: データセットからデータをバッチサイズ32でロードします。バッチサイズはbatch_size=32の部分で指定していますが、詳しくは後述します。

※transforms.Composeで登録された処理について

transforms.ToTensor()には役割が大きく2つあります。
1つは、画像データをPyTorchテンソルに変換すること。
2つ目は、データの値を0〜1の間になるように変換することです。これを「正規化といいます。今回は0〜255の間の値を取るピクセルの値を255で割り、0〜1の間になるように変換されます。

transforms.Normalize(mean, std) は、データセットの平均(mean)と標準偏差(std)を使用してデータを標準化します。

標準化」とは、データセット内の特徴量を一定の平均と標準偏差に基づいて変換するプロセスです。異なる特徴量のスケールを統一し、特徴量間の比較や組み合わせを容易にすることで、学習の効率と性能を向上させることができます。

例えば今回の例だと、各ピクセルの値は平均(0.5)と標準偏差(0.5)を使用して以下のようにデータを標準化することになります。本来は平均と標準偏差はデータセットから計算されるべきですが、今回は推測値として0.5で代用しています。

1. 各ピクセル値から0.5(指定された平均)を引く(各ピクセル値は-0.5〜+0.5になる)
2. その結果を0.5(指定された標準偏差)で割る(これにより標準偏差が1になる)

以上の操作により、データの範囲は -1〜+1 になり0を基準とする標準偏差1の分布を取るようになります。

検証用データとの分離

ダウンロードした画像データを全て学習に使ってしまうと正常に学習できたか評価ができません。
ですので、PyTorchの random_split 関数を使って、ダウンロードしたCIFAR-10データセットを学習用データセット(トレーニングセット)と検証用データセット(バリデーションセット)に分割します。

# トレーニングセットの全サイズを取得
total_size = len(trainset)

# トレーニングセットとバリデーションセットの割合を定義(例:トレーニング80%、バリデーション20%)
train_size = int(total_size * 0.8)
val_size = total_size - train_size

# トレーニングセットとバリデーションセットに分割
train_dataset, val_dataset = random_split(trainset, [train_size, val_size])

# トレーニングセットとバリデーションセット用のDataLoaderを作成
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
valloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

以上のようにDataLoaderを訓練用と検証用に分割することでモデルの性能を客観的に測定することができます。

画像の表示

実際にデータを取得し、画像として表示するコードを作成しましょう。以下のコードでは、CIFAR-10データセットから取得した画像を格子状に表示します。

dataiter = iter(trainloader)
images, labels = next(dataiter)

first_image = images[0]
print(first_image.shape)
print(first_image)

def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy() # テンソルからnupmyへ変換
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

imshow(torchvision.utils.make_grid(images))
  • dataiter = iter(trainloader)

トレーニングデータセットのイテレータを作成しています。イテレータは、データセットなどの要素の集まりを順番に一つずつ処理するためのオブジェクトです。
iter() 関数は、trainloaderからCIFAR-10データセットのイテレータを作成します。

  • images, labels = next(dataiter)

next() 関数はイテレータ(dataiter)から次の要素を取り出すために使用します。trainloaderを作成する際にbatch_size=32と指定したので、一度に32件分の画像とラベルデータを取得します。
今回はdataiterを作成して1度目にnext()実行しましたので、先頭から32件分のデータを取得します。仮にもう一度next()を実行すると、33件目から32件分のデータを取得することになります。

  • first_image = images[0]
    print(first_image.shape)
    print(first_image)

first_image は、CIFAR-10データセットから取り出された最初の画像データです。print(first_image.shape)の実行結果は「print(first_image.shape)」となります。これは画像データなので構造が「チャンネル(RGB), 高さ, 幅」の3次元テンソルとなるからです。
print(first_image)でデータの内容を表示しています。imagesはtransformによりピクセルデータが前処理された状態で取得されるので、first_imageは-1〜+1の間の値となっていることが分かります。

  • torchvision.utils.make_grid(images)

make_grid関数は画像データのリストをPyTorchのテンソル形式で受け取り、画像を格子状に並べ替えて、一つの大きな画像データをテンソル形式で返します。
つまり、36枚のテンソル形式の画像データを1枚のテンソル形式の画像データに変換しています。

  • def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy() # テンソルからnupmyへ変換
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

imshow関数は、画像データをテンソル形式で受け取り、画像を表示する関数です。
「img = img / 2 + 0.5」はtransforms.Normalizeと逆の処理を行い、標準化処理を解除して元のデータセットのスケールに戻しています。
PyTorchテンソルでは、画像の形式は (チャンネル, 高さ, 幅) となっていますが、matplotlibで画像を表示するためには (高さ, 幅, チャンネル) の形式に変換する必要があるので、np.transpose() 関数を使って軸の入れ替えを行なっています。
(チャンネル, 高さ, 幅)のチャンネルを0、高さを1、幅を2とし、「np.transpose(npimg, (1, 2, 0))」は(1, 2, 0)の順番に軸を入れ替えます。よって結果として、(高さ, 幅, チャンネル)の順番になります。
以上でmatplotlibが表示できるように値を整えたので、plt.imshow()とplt.show()で実際に画像を表示します。

ラベルの表示

labelsはその画像データが何の画像かを示すデータです。0は飛行機、1は車、2は鳥、3は猫というように決まっています。

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(32)))

画像が「猫」であるかを判断するモデルを構築する場合、ラベルは猫であるかが判断できれば良いので、猫を1、それ以外を0と変形します。

# ラベルが猫の場合は1、それ以外の場合は0に設定
binary_labels = (labels == 3).type(torch.float)

# 結果を確認
print("Original labels: ", labels)
print("Binary labels: ", binary_labels)

「binary_labels = (labels == 3).type(torch.long)」について説明します。labels == 3により、要素ごとに比較を行い、3の値の場合はtrueに、そうでない場合はfalseで返します。次に真偽値を数値の1, 0に変換するため、テンソルの型をtype(torch.float)で数値型に変換しています。

モデルの学習への利用

モデルの学習でデータセットを利用する際は、学習用の入力データと正解ラベル(猫の画像かどうか)の組み合わせを取得し学習させるという繰り返しを行うことになります。

そこで、trainloaderからデータを順番にデータを取得し、学習に使用するコードをご紹介します。

for i, data in enumerate(trainloader, 0):
    inputs, labels = data
    print(f"No.{i} 1件目のラベル:{labels[0]}")
    labels = (labels == 3).type(torch.float)
    # モデル学習の部分

このように既存のデータセットを使えば、自分でデータを集めることなく構築したモデルの学習や評価を進めることができます。Pytorchの学習を進めるにあたり必須となりますので、扱い方を覚えておくと良いでしょう。

タイトルとURLをコピーしました