【PyTorch入門】テンソルとは?初心者が知るべき基本を解説

PyTorch

テンソルは多次元配列を表現するためのデータ構造で、PyTorchでは基本的にテンソル型としてデータを扱うことになります。
NumPyの配列と似ていて混同しやすいので、もしNumPyについて知らなければ【機械学習入門】NumPyの基礎でNumPy配列について知っておくとスムーズです。

テンソルとは、スカラーとの違い

テンソルは、多次元配列を表現するためのデータ構造です。例えばベクトル、行列などといったデータ構造がテンソル型にあたります。テンソルは、機械学習や深層学習の分野で、データを表現するために広く用いられています。
多次元配列なので、単なる数値はテンソルではありません、スカラーといいます。テンソルはスカラーによって構成されます。

PyTorchのテンソルは、GPU上での計算や効率的な学習を行うための特徴を持った多次元配列のデータ構造となります。

テンソルの作成

PyTorchのテンソルは、torch.Tensor()を使って作成することができます。たとえば、次のようにテンソルを作成することができます。

import torch

# 1次元テンソルを作成
a = torch.tensor([1, 2, 3])

# 2次元テンソルを作成
b = torch.tensor([[1, 2], [3, 4]])

NumPyの配列もテンソルデータの一種です。torch.from_numpyでPyTorchのテンソルに変換することも、逆にNumPyの配列に変換することもできます。

import numpy as np
import torch

# NumPy配列を作成
arr = np.array([1, 2, 3])

# NumPy配列からテンソルを作成
a = torch.from_numpy(arr)

# テンソルをNumPy配列に変換
arr = a.numpy()

その他にも特殊な特徴を持ったテンソルを作成することもできます。

import torch

# 3行2列のゼロで初期化された2次元テンソルを作成
a = torch.zeros(3, 2)

# 3行2列の1で初期化された2次元テンソルを作成
b = torch.ones(3, 2)

# 3行2列のランダムな値で初期化された2次元テンソルを作成
c = torch.rand(3, 2)

テンソルを作成した後は、各種演算を行うことができます。

import torch

# 2次元テンソルを作成
a = torch.tensor([[1, 2], [3, 4]])

# 2次元テンソルを加算
b = a + a
print(b)
# tensor([[2, 4],
#         [6, 8]])

# 2次元テンソルの要素ごとの乗算
c = a * a
print(c)
# tensor([[ 1,  4],
#         [ 9, 16]])

# 2次元テンソルの転置
d = a.transpose(0, 1)
print(d)
# tensor([[1, 3],
#         [2, 4]])

テンソルの形状操作

NumPy配列と同様にPyTorchのテンソルも配列の形状を変えることができます。形状変更の前後で要素数が変わるとエラーになることに注意してください。
形状を変えるには以下のようにreshapeやview関数を使用します。

import torch

# 2次元テンソルを作成
a = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 形状を3行2列に変更
b = a.reshape(3, 2)

print(b)
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

# 2次元テンソルを作成
a = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 形状を3行2列に変更
b = a.view(3, 2)

print(b)
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

shapeで形状を、numelで要素数を確認することができます。

import torch

# 2次元テンソルを作成
a = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 形状を確認
print(a.shape)  # torch.Size([2, 3])

# 要素数を確認
print(a.numel())  # 6

テンソルのデータ型

PyTorchのテンソルには、様々なデータ型があります。異なるデータ型のテンソル同士を計算することはでき無いので、テンソルのデータ型は意識しておく必要があります。

import torch

# float型のテンソルを作成
a = torch.tensor([1.0, 2.0, 3.0])

# int型のテンソルを作成
b = torch.tensor([4, 5, 6], dtype=torch.int)

# 加算を実行
c = a + b  # エラー

デフォルトのデータ型はtorch.float32ですが、dtype引数を使用して、他のデータ型を指定することができます。

import torch

# int型のテンソルを作成
a = torch.tensor([1, 2, 3], dtype=torch.int)

print(a)
# tensor([1, 2, 3], dtype=torch.int32)
タイトルとURLをコピーしました