subplots()やadd_subplot()で複数の行数・列数のAxesを生成すると、Axesオブジェクトの2次元の配列となる。
この結果に対して一律に処理をしたい場合(たとえば軸の値や凡例を設定したい、アスペクトを揃えたいなどの場合)、いちいち二重ループを回すのが面倒。
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import numpy as np import matplotlib.pyplot as plt fig, axs = plt.subplots(2, 2) x = np.linspace(-np.pi, np.pi) n = 1 for row in axs: for ax in row: ax.plot(x, np.sin(n*x), label="n={}".format(n)) n += 1 ax.set_ylim(-1.2, 1.8) ax.legend(loc='upper left') plt.show() |
変換方法の1つは、以下のように1次元配列で取り出してしまう方法
|
1 |
axs_1d = [ax for row in axs for ax in row] |
あるいは、以下のように2次元配列を1次元に変換する方法(当初、reshape(1, -1)[0]のようなことをしていたが、reshape(-1)とすればよいことがわかった)。
|
1 |
axs_1d = axs.reshape(-1) |
こうすると1次元配列axs_1dで2次元のaxsの全要素に対してアクセス可能になる。
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import numpy as np import matplotlib.pyplot as plt fig, axs = plt.subplots(2, 2) axs_1d = axs.reshape(-1) x = np.linspace(-np.pi, np.pi) for n, ax in enumerate(axs_1d): ax.plot(x, np.sin(n*x), label="n={}".format(n)) ax.set_ylim(-1.2, 1.8) ax.legend(loc='upper left') plt.show() |
このほかにflatten()、ravel()を使う方法もある。flatten()はコピーを返すが、Axesオブジェクトへの参照先は変わらないので同じ効果。
各グラフにカウンターの値を適用するときはenumerate、他のリストなどと同時に変えていくときはzipを使う。