この記事では画像に対する機械学習の手法の1つSegmentationの実装方法について、具体例を用いて紹介していきます。
機械学習における画像処理は大きく次の3種類に分類されます。
今回はこのうちのSegmentationの実装方法について紹介していきます。
画像分類に関してはこちらの記事を参照ください。
Segmentationでは、特徴を学習させた機械学習モデルに、所望の画像を入力することで領域を抽出した画像(マスク画像)が得られます。以下の画像は自動車の特徴を学習させSegmentationを行なった例です。
Segmentationは自動運転技術や衛星画像の解析、レーダー波の解析など様々な分野で応用されています。
PytorchではSegmentationを簡易に利用できるオープンソース"Segmentation-Models-Pytorch"(以降SMPと略します)があるので、これを用いたSegmentationの手順を紹介していきます。
また、この記事内のコードはSegmentation-Models-Pytorchにあるサンプルコードを基本的に使用しています。サンプルコードはnotebook形式で作成されているため、GPU環境がない方はgoogle colaboratoryを活用することで簡単に再現可能です。
Segmentation Models Pytorch
Segmentation Models Pytorchは、名前の通りPytorch用のSegmentationライブラリです。このライブラリを用いることで様々なアーキテクチャ(例えばUnet、Feature Pyramid Networkなど)や、バックボーン(VGG, ResNet, EfficientNetなど)のSegmentationモデルを簡易に実装できます。
- SMPの特徴
- 数行のコードでNeural Networkを構築可能な高レベルAPI
- マルチクラスSegmentationに対応した様々なアーキテクチャ
- 113種類の事前学習済みバックボーンネットワーク
Segmentation Models Pytorchのインストール
pipに対応しているためインストールは次のコマンドで簡単にできます。
$ pip install -U segmentation-models-pytorch
もしくはソースからインストールする場合
$ pip install -U git+https://github.com/qubvel/segmentation_models.pytorch
インストール済みのパッケージ一覧を確認するpip freezeコマンドで、インストールされたか確認します。
$ pip freeze | grep segmentation
segmentation-models-pytorch==0.2.0
余談ですが、google colaboratoryを使用する場合は、サンプルコードが使用しているライブラリalbumentationのImportでエラーする可能性があります。
その場合は、colaboratory上でインストールし直すことで対処できました。
$ pip install -U git+https://github.com/albu/albumentations --no-cache-dir
データのダウンロード
冒頭でも記述しましたが、SMPにはいくつかのサンプルコードが用意されています。
この記事では、このうち「Training model for cars segmentation on CamVid dataset 」を使ってSegmentationの方法を紹介していきますので、使用するデータをダウンロードします。
これでインストールが完了したので、具体的な使い方を見ていきましょう。
Segmentationモデルの構築
実際に機械学習モデルを構築し、Segmentationを行なっていきます。Segmentationの結果として得られるマスク画像の生成までは、次の流れに沿って行います。
- 事前準備
- ライブラリのImport、データのダウンロード、関数の定義
- データセットの作成
- 生データから学習及び推論のデータセットを作成
- データローダーの作成
- データローダーによって、データセットから順次データをロードする
- モデル学習
- タスクに応じた機械学習アーキテクチャを構築し、データセットを用いモデルを学習
- 新規データの推論
- 学習済みモデルを用い新規のデータセットを推論
事前準備
まず必要なライブラリをImportしていきます。
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import segmentation_models_pytorch as smp
import albumentations as albu
次にSegmentationに用いるデータ(CamVid dataset)をダウンロードします。
DATA_DIR = './data/CamVid/'
# load repo with data if it is not exists
if not os.path.exists(DATA_DIR):
print('Loading data...')
os.system('git clone https://github.com/alexgkendall/SegNet-Tutorial ./data')
print('Done!')
x_train_dir = os.path.join(DATA_DIR, 'train')
y_train_dir = os.path.join(DATA_DIR, 'trainannot')
x_valid_dir = os.path.join(DATA_DIR, 'val')
y_valid_dir = os.path.join(DATA_DIR, 'valannot')
x_test_dir = os.path.join(DATA_DIR, 'test')
y_test_dir = os.path.join(DATA_DIR, 'testannot')
最後に、サンプルコードで使用する関数を定義していきます。
def get_training_augmentation():
train_transform = [
albu.PadIfNeeded(min_height=320, min_width=320, always_apply=True, border_mode=0),
albu.RandomCrop(height=320, width=320, always_apply=True),
]
return albu.Compose(train_transform)
def get_validation_augmentation():
"""画像のshapeが32で割り切れるようにPaddingするための関数"""
test_transform = [
albu.PadIfNeeded(384, 480)
]
return albu.Compose(test_transform)
def to_tensor(x, **kwargs):
return x.transpose(2, 0, 1).astype('float32')
def get_preprocessing(preprocessing_fn):
"""Construct preprocessing transform
Args:
preprocessing_fn (callbale): data normalization function
(can be specific for each pretrained neural network)
Return:
transform: albumentations.Compose
"""
_transform = [
albu.Lambda(image=preprocessing_fn),
albu.Lambda(image=to_tensor, mask=to_tensor),
]
return albu.Compose(_transform)
# 可視化用の関数
def visualize(**images):
"""PLot images in one row."""
n = len(images)
plt.figure(figsize=(16, 5))
for i, (name, image) in enumerate(images.items()):
plt.subplot(1, n, i + 1)
plt.xticks([])
plt.yticks([])
plt.title(' '.join(name.split('_')).title())
plt.imshow(image)
plt.show()
これで準備が完了したのでデータセットの作成に移ります。
データセットの作成
データセットを作成し、学習や推論に適した形に生データを変換します。PyTorchでは基本的にこのデータセットからデータをロードすることで学習、推論を効率的に実施できます。
データセット作成のポイントは次の通りです。
- torch.utils.data.Datasetを継承したClassとして作成する
- コンストラクタ(__init__)では、生データの格納場所を渡す
- Classには最低でも、__getitem__メソッド、__len__メソッドを作成する必要がある
- __getitem__メソッドは、学習もしくは推論データと学習する特徴(学習時のみ)を返す
- ___len__メソッドはデータセットの長さを返す
それでは、このポイントを踏まえてDataset Classを作成していきます。
コンストラクタ(__init__)では、生データが格納されているディレクトリのPath(images_dir, masks_dir)、__getitem__メソッドで利用する関数を指定します。
# 1. torch.utils.data.Datasetを継承したDataset classを定義
class Dataset(torch.utils.data.Dataset):
CLASSES = ['sky', 'building', 'pole', 'road', 'pavement',
'tree', 'signsymbol', 'fence', 'car',
'pedestrian', 'bicyclist', 'unlabelled']
def __init__(
self,
images_dir, # 画像のPath
masks_dir, # マスク画像のPath
classes=None, # 推論対象のクラス
augmentation=None, # augmentation用関数
preprocessing=None, # 前処理用関数
):
self.ids = os.listdir(images_dir)
self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
# クラス名の文字列('car', 'sky'など)をIDに変換
self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
self.augmentation = augmentation
self.preprocessing = preprocessing
__getitem__メソッドでは、コンストラクタで定義したディレクトリからデータを1つずつ取得し、必要な処理を適用した後、データセットを返します。
# 3. 学習用データ(image)と特徴(mask)を返す__getitem__メソッドを作成
def __getitem__(self, i):
# データの読み込み
image = cv2.imread(self.images_fps[i])
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask = cv2.imread(self.masks_fps[i], 0)
# 学習対象のクラス(例えば、'car')のみを抽出
masks = [(mask == v) for v in self.class_values]
mask = np.stack(masks, axis=-1).astype('float')
# augmentation関数の適用
if self.augmentation:
sample = self.augmentation(image=image, mask=mask)
image, mask = sample['image'], sample['mask']
# 前処理関数の適用
if self.preprocessing:
sample = self.preprocessing(image=image, mask=mask)
image, mask = sample['image'], sample['mask']
return image, mask
__len__メソッドでは、コンストラクタで指定した生データの長さを取得します。
# 4. データセットの長さを返す__len__を作成
def __len__(self):
return len(self.ids)
これでデータセットClassが作成できました。
データローダーの作成
学習/推論ループでデータセットを読み込むためにデータローダーを作成します。
データローダーの作成は簡単です。torch.utils.data.DataLoaderにデータセットなどの必要な情報を渡せば作成できます。
from torch.utils.data import DataLoader
# データセットのインスタンスを作成
train_dataset = Dataset(
x_train_dir,
y_train_dir,
augmentation=get_training_augmentation(),
preprocessing=get_preprocessing(preprocessing_fn),
classes=CLASSES,
)
# データローダーの作成
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=12)
推論用のDataLoaderを作成する際は引数の、batch_size=1, shuffle=Falseにするようにしましょう。
それでは、これらを用いてモデルを学習していきます。
モデル学習
先ほど作成したデータセットを用いて、次のステップでモデル学習を実施していきます。
- モデルの取得
- Segmentation-models-pytorchを用いてモデルを取得
- 損失関数、最適化関数の設定
- タスクに応じた損失関数、評価関数を設定
- モデル学習
- データセットを用いモデルを学習し、保存
使用するアーキテクチャを決めてモデルを取得します。
SMPで利用可能なアーキテクチャの一覧はこちらから確認できます。
# パラメータ
ENCODER = 'se_resnext50_32x4d' # バックボーンネットワークの指定
ENCODER_WEIGHTS = 'imagenet' # 使用する学習済みモデル
ACTIVATION = 'sigmoid' # 恒等関数や、マルチクラス用に'softmax2d'にする場合は'None'へ
CLASSES = ['car']
DEVICE = 'cuda'
# SMPを用いて学習済みモデルを取得(アーキテクチャはFPN)
model = smp.FPN(
encoder_name=ENCODER,
encoder_weights=ENCODER_WEIGHTS,
classes=len(CLASSES),
activation=ACTIVATION,
)
上記のコードを実行すると、初回は学習済みモデルのダウンロードが行われた後にモデルが取得できます。取得したモデルはprint文で出力すると構成を確認できます。
print(model)
# 出力結果(抜粋)
FPN(
(encoder): SENetEncoder(
(layer0): Sequential(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU(inplace=True)
(pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
)
~~~ 略 ~~~
(segmentation_head): SegmentationHead(
(0): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1))
(1): UpsamplingBilinear2d(scale_factor=4.0, mode=bilinear)
(2): Activation(
(activation): Sigmoid()
)
)
)
続いて損失関数(loss)及び最適化関数(optimizer)を設定していきます。
# 損失関数
loss = smp.utils.losses.DiceLoss()
# 評価関数
metrics = [
smp.utils.metrics.IoU(threshold=0.5),
]
# 最適化関数
optimizer = torch.optim.Adam([
dict(params=model.parameters(), lr=0.0001),
])
ここまで来たらエポック単位の学習処理を定義し、モデルを学習します。
注意点として、google colaboratoryを用いている場合はGPUを有効にするのを忘れないようにしましょう。(「ランタイム」 -> 「ランタイムのタイプを変更」から設定できます)
# SMPに用意されているシンプルなループ関数(Train用)
train_epoch = smp.utils.train.TrainEpoch(
model,
loss=loss,
metrics=metrics,
optimizer=optimizer,
device=DEVICE,
verbose=True,
)
# SMPに用意されているシンプルなループ関数(Validation用)
valid_epoch = smp.utils.train.ValidEpoch(
model,
loss=loss,
metrics=metrics,
device=DEVICE,
verbose=True,
)
# 学習ループの実行
n_epoch = 40 # エポック数
max_score = 0
for i in range(0, n_epoch):
print('\nEpoch: {}'.format(i))
train_logs = train_epoch.run(train_loader)
valid_logs = valid_epoch.run(valid_loader)
# 評価関数の値が更新されたらモデルを保存
if max_score < valid_logs['iou_score']:
max_score = valid_logs['iou_score']
torch.save(model, './best_model.pth')
print('Model saved!')
# エポック25以降は学習率(learning rate)を下げる
if i == 25:
optimizer.param_groups[0]['lr'] = 1e-5
print('Decrease decoder learning rate to 1e-5!')
これでモデルの学習まで完了しましたので、最後に新規データに対して推論を行い自動車が抽出されるかを確認します。
新規データの推論
モデルの学習まで完了したので、新規データに対してモデルを用いて推論していきます。
推論処理でやることは今まで行ってきた内容とほとんど同じで、
- 学習モデルの読み込み
- 推論用データセット、データローダーの作成
- 推論用データセット及び学習済みモデルを用いて推論
- 推論画像の可視化
この順番で行います。
ちなみに4番の可視化は結果確認用ですので、推論をするだけでしたら3番までで完了しています。
推論に先立ち、可視化用の関数をいくつか準備しておきます。
# 可視化用の画像を取得するデータセットを作成(Augmentationなし)
test_dataset_vis = Dataset(
x_test_dir, y_test_dir,
classes=CLASSES,
)
# 可視化用の関数
def visualize(**images):
"""PLot images in one row."""
n = len(images)
plt.figure(figsize=(16, 5))
for i, (name, image) in enumerate(images.items()):
plt.subplot(1, n, i + 1)
plt.xticks([])
plt.yticks([])
plt.title(' '.join(name.split('_')).title())
plt.imshow(image)
plt.show()
準備が整ったので、推論を実施していきます。
# 1. 学習モデルの読み込み
best_model = torch.load('./best_model.pth')
# 2. 推論用のデータセット、データローダーの作成
test_dataset = Dataset(
x_test_dir,
y_test_dir,
augmentation=get_validation_augmentation(),
preprocessing=get_preprocessing(preprocessing_fn),
classes=CLASSES,
)
test_dataloader = DataLoader(test_dataset)
n_data = 2 # 確認するデータの数
for i in range(n_data):
n = np.random.choice(len(test_dataset))
# 3. 新規データの取得
image_vis = test_dataset_vis[n][0].astype('uint8')
image, gt_mask = test_dataset[n]
gt_mask = gt_mask.squeeze()
# 3. 新規データの推論
x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
pr_mask = best_model.predict(x_tensor)
pr_mask = (pr_mask.squeeze().cpu().numpy().round())
# 4. 可視化
visualize(
image=image_vis,
ground_truth_mask=gt_mask,
predicted_mask=pr_mask
)
推論結果を可視化した画像は次の通りです。
左から、生画像、政界のマスク画像、予測した推論画像です。自動車の特徴を捉えたモデルが学習できていることが分かります。
終わりに
画像における機械学習のタスクとしては、画像分類、物体検出、セグメンテーションがあり、今回はそのうちSegmentation Models.pytorchを用い、自動車の位置を取得するたセグメンテーションの方法を紹介しました。
画像分類については以前手書き数字の画像MNISTを用いた記事で紹介しておりますので、物体検出の記事を書いたら、画像における機械学習タスクのまとめ記事を執筆予定です。
コメント