【画像処理】PyTorchでSegmentationを実装する方法【機械学習】

segmentation_title PyTorch

この記事では画像に対する機械学習の手法の1つSegmentationの実装方法について、具体例を用いて紹介していきます。

この記事で学べる内容
  • PyTorchを使用したSegmentationの実装方法

機械学習における画像処理は大きく次の3種類に分類されます。

  • 画像分類 何の画像かを分類する
  • 物体検出 画像のどこに何が映っているかを検出する
  • Segmentation 領域単位で何を表しているかを分類する

今回はこのうちのSegmentationの実装方法について紹介していきます。
画像分類に関してはこちらの記事を参照ください。

Segmentationでは、特徴を学習させた機械学習モデルに、所望の画像を入力することで領域を抽出した画像(マスク画像)が得られます。以下の画像は自動車の特徴を学習させSegmentationを行なった例です。

Segmentationは自動運転技術や衛星画像の解析、レーダー波の解析など様々な分野で応用されています。
PytorchではSegmentationを簡易に利用できるオープンソース"Segmentation-Models-Pytorch"(以降SMPと略します)があるので、これを用いたSegmentationの手順を紹介していきます。
また、この記事内のコードはSegmentation-Models-Pytorchにあるサンプルコードを基本的に使用しています。サンプルコードはnotebook形式で作成されているため、GPU環境がない方はgoogle colaboratoryを活用することで簡単に再現可能です。

Segmentation Models Pytorch

Segmentation Models Pytorchは、名前の通りPytorch用のSegmentationライブラリです。このライブラリを用いることで様々なアーキテクチャ(例えばUnet、Feature Pyramid Networkなど)や、バックボーン(VGG, ResNet, EfficientNetなど)のSegmentationモデルを簡易に実装できます。

  • SMPの特徴
    • 数行のコードでNeural Networkを構築可能な高レベルAPI
    • マルチクラスSegmentationに対応した様々なアーキテクチャ
    • 113種類の事前学習済みバックボーンネットワーク

Segmentation Models Pytorchのインストール

pipに対応しているためインストールは次のコマンドで簡単にできます。

$ pip install -U segmentation-models-pytorch

もしくはソースからインストールする場合
$ pip install -U git+https://github.com/qubvel/segmentation_models.pytorch

インストール済みのパッケージ一覧を確認するpip freezeコマンドで、インストールされたか確認します。

$ pip freeze | grep segmentation
segmentation-models-pytorch==0.2.0

余談ですが、google colaboratoryを使用する場合は、サンプルコードが使用しているライブラリalbumentationのImportでエラーする可能性があります。
その場合は、colaboratory上でインストールし直すことで対処できました。

$ pip install -U git+https://github.com/albu/albumentations --no-cache-dir

データのダウンロード

冒頭でも記述しましたが、SMPにはいくつかのサンプルコードが用意されています。
この記事では、このうち「Training model for cars segmentation on CamVid dataset 」を使ってSegmentationの方法を紹介していきますので、使用するデータをダウンロードします。
これでインストールが完了したので、具体的な使い方を見ていきましょう。

Segmentationモデルの構築

実際に機械学習モデルを構築し、Segmentationを行なっていきます。Segmentationの結果として得られるマスク画像の生成までは、次の流れに沿って行います。

  1. 事前準備
    • ライブラリのImport、データのダウンロード、関数の定義
  2. データセットの作成
    • 生データから学習及び推論のデータセットを作成
  3. データローダーの作成
    • データローダーによって、データセットから順次データをロードする
  4. モデル学習
    • タスクに応じた機械学習アーキテクチャを構築し、データセットを用いモデルを学習
  5. 新規データの推論
    • 学習済みモデルを用い新規のデータセットを推論

事前準備

まず必要なライブラリをImportしていきます。

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import cv2
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

import segmentation_models_pytorch as smp
import albumentations as albu

次にSegmentationに用いるデータ(CamVid dataset)をダウンロードします。

DATA_DIR = './data/CamVid/'

# load repo with data if it is not exists
if not os.path.exists(DATA_DIR):
    print('Loading data...')
    os.system('git clone https://github.com/alexgkendall/SegNet-Tutorial ./data')
    print('Done!')

x_train_dir = os.path.join(DATA_DIR, 'train')
y_train_dir = os.path.join(DATA_DIR, 'trainannot')

x_valid_dir = os.path.join(DATA_DIR, 'val')
y_valid_dir = os.path.join(DATA_DIR, 'valannot')

x_test_dir = os.path.join(DATA_DIR, 'test')
y_test_dir = os.path.join(DATA_DIR, 'testannot')

最後に、サンプルコードで使用する関数を定義していきます。

def get_training_augmentation():
    train_transform = [
        albu.PadIfNeeded(min_height=320, min_width=320, always_apply=True, border_mode=0),
        albu.RandomCrop(height=320, width=320, always_apply=True),
    ]
    return albu.Compose(train_transform)

def get_validation_augmentation():
    """画像のshapeが32で割り切れるようにPaddingするための関数"""
    test_transform = [
        albu.PadIfNeeded(384, 480)
    ]
    return albu.Compose(test_transform)

def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

def get_preprocessing(preprocessing_fn):
    """Construct preprocessing transform    
    Args:
        preprocessing_fn (callbale): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose    
    """
    
    _transform = [
        albu.Lambda(image=preprocessing_fn),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

# 可視化用の関数
def visualize(**images):
    """PLot images in one row."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()

これで準備が完了したのでデータセットの作成に移ります。

データセットの作成

データセットを作成し、学習や推論に適した形に生データを変換します。PyTorchでは基本的にこのデータセットからデータをロードすることで学習、推論を効率的に実施できます。

データセット作成のポイントは次の通りです。

  1. torch.utils.data.Datasetを継承したClassとして作成する
  2. コンストラクタ(__init__)では、生データの格納場所を渡す
  3. Classには最低でも、__getitem__メソッド、__len__メソッドを作成する必要がある
  4. __getitem__メソッドは、学習もしくは推論データと学習する特徴(学習時のみ)を返す
  5. ___len__メソッドはデータセットの長さを返す

それでは、このポイントを踏まえてDataset Classを作成していきます。
コンストラクタ(__init__)では、生データが格納されているディレクトリのPath(images_dir, masks_dir)、__getitem__メソッドで利用する関数を指定します。

# 1. torch.utils.data.Datasetを継承したDataset classを定義
class Dataset(torch.utils.data.Dataset):
    CLASSES = ['sky', 'building', 'pole', 'road', 'pavement', 
               'tree', 'signsymbol', 'fence', 'car', 
               'pedestrian', 'bicyclist', 'unlabelled']
    
    def __init__(
            self, 
            images_dir, # 画像のPath
            masks_dir, # マスク画像のPath
            classes=None, # 推論対象のクラス
            augmentation=None, # augmentation用関数
            preprocessing=None, # 前処理用関数
    ):
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
        
        # クラス名の文字列('car', 'sky'など)をIDに変換
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
        
        self.augmentation = augmentation
        self.preprocessing = preprocessing

__getitem__メソッドでは、コンストラクタで定義したディレクトリからデータを1つずつ取得し、必要な処理を適用した後、データセットを返します。

    # 3. 学習用データ(image)と特徴(mask)を返す__getitem__メソッドを作成
    def __getitem__(self, i):
        
        # データの読み込み
        image = cv2.imread(self.images_fps[i])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks_fps[i], 0)
        
        # 学習対象のクラス(例えば、'car')のみを抽出
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')
        
        # augmentation関数の適用
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        
        # 前処理関数の適用
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
            
        return image, mask

__len__メソッドでは、コンストラクタで指定した生データの長さを取得します。

    # 4. データセットの長さを返す__len__を作成
    def __len__(self):
        return len(self.ids)

これでデータセットClassが作成できました。

データローダーの作成

学習/推論ループでデータセットを読み込むためにデータローダーを作成します。
データローダーの作成は簡単です。torch.utils.data.DataLoaderにデータセットなどの必要な情報を渡せば作成できます。

from torch.utils.data import DataLoader

# データセットのインスタンスを作成
train_dataset = Dataset(
    x_train_dir, 
    y_train_dir, 
    augmentation=get_training_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)

# データローダーの作成
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=12)

推論用のDataLoaderを作成する際は引数の、batch_size=1, shuffle=Falseにするようにしましょう。

それでは、これらを用いてモデルを学習していきます。

モデル学習

先ほど作成したデータセットを用いて、次のステップでモデル学習を実施していきます。

  1. モデルの取得
    • Segmentation-models-pytorchを用いてモデルを取得
  2. 損失関数、最適化関数の設定
    • タスクに応じた損失関数、評価関数を設定
  3. モデル学習
    • データセットを用いモデルを学習し、保存

使用するアーキテクチャを決めてモデルを取得します。
SMPで利用可能なアーキテクチャの一覧はこちらから確認できます。

# パラメータ
ENCODER = 'se_resnext50_32x4d' # バックボーンネットワークの指定
ENCODER_WEIGHTS = 'imagenet' # 使用する学習済みモデル
ACTIVATION = 'sigmoid' # 恒等関数や、マルチクラス用に'softmax2d'にする場合は'None'へ
CLASSES = ['car']
DEVICE = 'cuda'

# SMPを用いて学習済みモデルを取得(アーキテクチャはFPN)
model = smp.FPN(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=len(CLASSES), 
    activation=ACTIVATION,
)

上記のコードを実行すると、初回は学習済みモデルのダウンロードが行われた後にモデルが取得できます。取得したモデルはprint文で出力すると構成を確認できます。

print(model)

# 出力結果(抜粋)
FPN(
  (encoder): SENetEncoder(
    (layer0): Sequential(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    )
~~~ 略 ~~~
  (segmentation_head): SegmentationHead(
    (0): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1))
    (1): UpsamplingBilinear2d(scale_factor=4.0, mode=bilinear)
    (2): Activation(
      (activation): Sigmoid()
    )
  )
)

続いて損失関数(loss)及び最適化関数(optimizer)を設定していきます。

# 損失関数
loss = smp.utils.losses.DiceLoss() 

# 評価関数
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

# 最適化関数
optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.0001),
])

ここまで来たらエポック単位の学習処理を定義し、モデルを学習します。
注意点として、google colaboratoryを用いている場合はGPUを有効にするのを忘れないようにしましょう。(「ランタイム」 -> 「ランタイムのタイプを変更」から設定できます)

# SMPに用意されているシンプルなループ関数(Train用)
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

# SMPに用意されているシンプルなループ関数(Validation用)
valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

# 学習ループの実行
n_epoch = 40 # エポック数
max_score = 0
for i in range(0, n_epoch):    
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    
    # 評価関数の値が更新されたらモデルを保存
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './best_model.pth')
        print('Model saved!')

    # エポック25以降は学習率(learning rate)を下げる      
    if i == 25:
        optimizer.param_groups[0]['lr'] = 1e-5
        print('Decrease decoder learning rate to 1e-5!')

これでモデルの学習まで完了しましたので、最後に新規データに対して推論を行い自動車が抽出されるかを確認します。

新規データの推論

モデルの学習まで完了したので、新規データに対してモデルを用いて推論していきます。
推論処理でやることは今まで行ってきた内容とほとんど同じで、

  1. 学習モデルの読み込み
  2. 推論用データセット、データローダーの作成
  3. 推論用データセット及び学習済みモデルを用いて推論
  4. 推論画像の可視化

この順番で行います。
ちなみに4番の可視化は結果確認用ですので、推論をするだけでしたら3番までで完了しています。
推論に先立ち、可視化用の関数をいくつか準備しておきます。

# 可視化用の画像を取得するデータセットを作成(Augmentationなし)
test_dataset_vis = Dataset(
    x_test_dir, y_test_dir, 
    classes=CLASSES,
)

# 可視化用の関数
def visualize(**images):
    """PLot images in one row."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()

準備が整ったので、推論を実施していきます。

# 1. 学習モデルの読み込み
best_model = torch.load('./best_model.pth')

# 2. 推論用のデータセット、データローダーの作成
test_dataset = Dataset(
    x_test_dir, 
    y_test_dir, 
    augmentation=get_validation_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn),
    classes=CLASSES,
)
test_dataloader = DataLoader(test_dataset)

n_data = 2 # 確認するデータの数
for i in range(n_data):
    n = np.random.choice(len(test_dataset))

    # 3. 新規データの取得
    image_vis = test_dataset_vis[n][0].astype('uint8')
    image, gt_mask = test_dataset[n]
    gt_mask = gt_mask.squeeze()
    
    # 3. 新規データの推論
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    pr_mask = best_model.predict(x_tensor)
    pr_mask = (pr_mask.squeeze().cpu().numpy().round())
        
    # 4. 可視化
    visualize(
        image=image_vis, 
        ground_truth_mask=gt_mask, 
        predicted_mask=pr_mask
    )

推論結果を可視化した画像は次の通りです。
左から、生画像、政界のマスク画像、予測した推論画像です。自動車の特徴を捉えたモデルが学習できていることが分かります。

終わりに

画像における機械学習のタスクとしては、画像分類、物体検出、セグメンテーションがあり、今回はそのうちSegmentation Models.pytorchを用い、自動車の位置を取得するたセグメンテーションの方法を紹介しました。
画像分類については以前手書き数字の画像MNISTを用いた記事で紹介しておりますので、物体検出の記事を書いたら、画像における機械学習タスクのまとめ記事を執筆予定です。

リンク

コメント

タイトルとURLをコピーしました