matplotlibで複数のグラフを並べて表示する方法 plt.subplots

matplotlib

画像データとテーブルデータを対象にmatplotlibを用い複数のグラフを並べて表示する方法を紹介します。

学べること
  • 画像データを複数並べて表示する方法
  • テーブル形式のデータを複数の表示形式で並べて表示する方法

ゴールイメージはこちらです。画像、テーブルデータのグラフを並べて表示しています。

画像に alt 属性が指定されていません。ファイル名: image-1.png
画像に alt 属性が指定されていません。ファイル名: image-5.png

使用データの紹介

本題に入る前に今回使用するデータを紹介します。
画像データは手書き数字のMNIST画像、テーブルデータはボストンの住宅価格を利用します。どちらもscikit-learnから以下のコードで取得可能です。

# 手書き文字画像の取得
from sklearn.datasets import fetch_openml # scikit-learnのデータ取得ライブラリ
x_mnist, t_mnist = fetch_openml('mnist_784', version=1, return_X_y=True)

x_mnist = x_mnist.reshape([x_mnist.shape[0],28,28]) # 画像として扱うために(784 -> (28,28))にReshapeします。
print(x_mnist.shape) # (70000, 28, 28)
print(type(x_mnist)) # numpy.ndarray

# テーブルデータ(ボストンの住宅価格)の取得
import pandas as pd
from sklearn.datasets import load_boston

bsdata = load_boston()
df_bstn = pd.DataFrame(bsdata.data, columns=bsdata.feature_names) #住宅価格の説明変数
df_bstn['Price'] = bsdata.target # 住宅価格

画像データを並べて表示する plt.subplots, ax.imshow

画像を1枚表示するコードは次のようにかけますので、これを少し編集し複数画像を並べて表示していきます。


import matplotlib.pyplot as plt
fig, ax = plt.subplots()

ax.set_title(t_mnist[0], fontsize=16, color='white')
ax.axes.xaxis.set_visible(False) # X軸を非表示に
ax.axes.yaxis.set_visible(False) # Y軸を非表示に
ax.imshow(x_mnist[0],cmap='Greys') # 画像を表示
画像を単体表示

複数の画像を並べて表示させる場合は、figの中に複数のaxを表示させることになります。イメージはこうです。
fig, axの詳細は、ページの最後に公式サイトへのリンクを記載しているのでそちらを参照ください。ここではFigがグラフを表示するフィールドで、axで表示するグラフを指定していると思ってください。

figの中には複数のaxがあるので、どのaxかを引数で指定しながら画像を表示します。
それではコードを見ていきましょう。

n_data = 6 # 表示するデータ数
row=2 # 行数
col=3 # 列数
fig, ax = plt.subplots(nrows=row, ncols=col,figsize=(8,6))

fig.suptitle("MNIST data-set", fontsize=24, color='white')
for i, img in enumerate(x_mnist[:n_data]):
    _r= i//col
    _c= i%col
    ax[_r,_c].set_title(t_mnist[i], fontsize=16, color='white')
    ax[_r,_c].axes.xaxis.set_visible(False) # X軸を非表示に
    ax[_r,_c].axes.yaxis.set_visible(False) # Y軸を非表示に
    ax[_r,_c].imshow(img,cmap='Greys') # 画像を表示
複数のグラフを並列表示

指定した場所(_r, _c)に指定した画像が表示されているのがわかります。

テーブルデータを並べて表示 plt.subplots, ax.hist, ax.scatter

画像の表示ができたので次はテーブルデータを表示してみましょう。
まず簡単にボストン住宅価格データの中身をみてみます。

パッと見ではなんの情報か分かりづらいですが、右端のPriceが住宅価格($1,000単位)です。今回は住宅価格のヒストグラムと、住宅価格と部屋数(RM)、低所得者人口の割合(LSTAT)の散布図を表示します。
各指標の意味を知りたい方は公式サイトに説明がありますので参照ください。

df_table = df_bstn[['Price','RM','LSTAT']]# 表示するデータを指定

row = 1 # グラフの行数
col = 3 # グラフの列数
fig,ax=plt.subplots(row, col, figsize=(12,4))
for i, key in enumerate(df_table.keys()):    
    if key == 'Price': # 住宅価格はヒストグラムで表示
        ax[i].hist(df_table[key])
    else: # 住宅価格との散布図を表示
        ax[i].scatter(df_bstn['Price'], df_table[key])
    ax[i].set_title(key, fontsize=16) # グラフのタイトル
    ax[i].set_xlabel('Price ($1,000)', fontsize=12) # グラフのX軸名
    # 注:一行の場合はaxの引数指定はax[i]のように1次元で指定。

左から順番に住宅価格のヒストグラム、住宅価格と部屋数、低所得者層の割合との散布図が表示されます。

参考リンク

  • matplotlib Usage guide
    matplotlibの公式サイト。FIgureの構成要素やFig, axes, axisの説明が記載されています。

コメント

タイトルとURLをコピーしました