前回の記事でmatplotlib による3Dグラフの描き方を紹介しました。

今回は、もっと簡単に3Dグラフが描けるようクラス化してみました。
データを与えるだけでグリグリ回る3Dグラフが出来上がります。
ついで、3Dグラフを描画するためのデモデータのサンプルもクラス化してみました。
興味のある方は是非、ご一読下さい。
最初の準備
今回紹介するクラスを使うには matplotlib.pyplot と mpl_toolkits.mplot3d をインポートする必要があるので、下記2行を冒頭に追記して下さい。
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
Anaconda 経由で Python をインストールしていると既に含まれていますが、python 本家のサイトから Pythonをインストールしている場合、pip コマンドによるインストールが必要となります。
pip install matplotlib
Plot3Dクラスの使い方
クラス名は Plot3D という名前にしました。
使い方は次の通りです。
ちなみに、「XYZラベルの設定」は省略可能で、省略した場合は”X”、”Y”、”Z” が軸のラベル名として設定されます。

このクラスを使って3D空間に散布図と折れ線を引くには、次の様になります。
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#描画するデータの定義
x=[1,2,3,4,5]
y=[5,4,3,2,1]
z=[1,2,3,2,1]
#3Dグラフの表示
plt3d = Plot3D("3Dグラフ")
plt3d.set_data(x,y,z)
plt3d.scatter()
plt3d.plot()
plt3d.show()
下記が表示されるグラフです。

Plot3Dのリファレンス
ソースコードのコメントを使ってリファレンスを作ってみました。
| 機能 | メソッドとパラメータ |
|---|---|
| コンストラクタ | __init__(self,title,size=15,color="black") title:str グラフタイトル size:int グラフタイトルのフォントサイズ color:str グラフタイトルの色 |
| 軸ラベルのサイズと色を設定 | set_label(self,xlabel,ylabel,zlabel,size=12,color="black") xlabel:int X軸のラベル ylabel:int Y軸のラベル zlabel:int Z軸のラベル size : int ラベルのフォントサイゼ宇 color:str ラベルの色 |
| 描画データの登録 | set_data(self,x,y,z,values = None) x:int Xの値 y:int Yの値 z:int Zの値 value:int XYZ座標のデータが持つ値(初期値はNone) |
| 線の描画 | plot(self,width=1,color="black") size:int 線の太さ color:str 線の色 |
| 散布図の描画 | scatter(self,size=1,color="black") size :int マーカーサイズ color : str マーカー色 |
| カラーバー付き散布図の描画 | cmap_scatter(self,size=1,cmap="hsv") size :int マーカーサイズ cmap : str カラーマップ名 |
| 曲面の描画 | surface(self,cmap="hsv",alpha=0.8) cmap : str カラーマップ名 alpha:float 面の透明度 |
| ワイヤーフレームの描画 | wireframe(self,color="blue") color : str ワイヤーの色 |
| 底面の等高線描画 | contour(self,color="black",offset=None) color : str 等高線の色 offset : int,double 等高線のZ位置(0からのオフセット) |
| グラフの描画 | show() |
ソースコード
クラスのソースコードは以下の通りです。
基本的な内容をクラス化していますので、必要に応じて修正、追記の上、お使いください。
class Plot3D:
def __init__(self,title,size=15,color="black"):
'''
コンストラクタ
Parameters
----------
title:str グラフタイトル
size:int グラフタイトルのフォントサイズ
color:str グラフタイトルの色
'''
#タイトルで漢字が使えるようフォントを設定
plt.rcParams['font.family'] = 'Meiryo'
#描画エリアの作成
self.fig = plt.figure()
self.ax = self.fig.add_subplot(projection='3d')
self.ax.set_title(title,size=size,color="black")
self.set_label('X','Y','Z')
self.x = []
self.y = []
self.z = []
self.value = []
def set_label(self,xlabel,ylabel,zlabel,size=12,color="black") :
'''
軸ラベルのサイズと色を設定
Parameters
----------
xlabel:int X軸のラベル
ylabel:int Y軸のラベル
zlabel:int Z軸のラベル
size : int ラベルのフォントサイゼ宇
color:str ラベルの色
'''
self.ax.set_xlabel(xlabel,size=size,color=color)
self.ax.set_ylabel(ylabel,size=size,color=color)
self.ax.set_zlabel(zlabel,size=size,color=color)
def set_data(self,x,y,z,values = None):
'''
描画データの登録
Parameters
----------
x:int Xの値
y:int Yの値
z:int Zの値
value:int XYZ座標のデータが持つ値
'''
self.x = x
self.y = y
self.z = z
self.values = values
def plot(self,width=1,color="black"):
'''
線の描画
Parameters
----------
size:int 線の太さ
color:str 線の色
'''
self.ax.plot(self.x,self.y,self.z,color=color,linewidth=width)
def scatter(self,size=1,color="black"):
'''
散布図の描画
Parameters
----------
size :int マーカーサイズ
color : str マーカー色
'''
self.ax.scatter(self.x,self.y,self.z,s=size,color=color)
def cmap_scatter(self,size=1,cmap="hsv"):
'''
カラーバー付き散布図の描画
Parameters
----------
size :int マーカーサイズ
cmap : str カラーマップ名
'''
val = self.z if self.values is None else self.values
mappable = self.ax.scatter(self.x,self.y,self.z,c=val,s=size,cmap=cmap)
self.fig.colorbar(mappable, ax=self.ax)
def surface(self,cmap="hsv",alpha=0.8):
'''
曲面の描画
Parameters
----------
cmap : str カラーマップ名
alpha:float 面の透明度
'''
self.ax.plot_surface(self.x,self.y,self.z,cmap=cmap,alpha=alpha)
def wireframe(self,color="blue"):
'''
ワイヤーフレームの描画
Parameters
----------
color : str ワイヤーの色
'''
self.ax.plot_wireframe(self.x,self.y,self.z,color=color)
def contour(self,color="black",offset=None):
'''
底面の等高線描画
Parameters
----------
color : str 等高線の色
offset : int,double 等高線のZ位置(0からのオフセット)
'''
offset = np.min(self.z) if offset is None else offset
self.ax.contour(self.x,self.y,self.z,colors="black",offset=offset)
def show(self):
'''
グラフの描画
'''
plt.show()
3Dグラフ用サンプルデータ自動生成クラス
3Dグラフでググってみると、三角関数を使った綺麗なグラフの表示例が見つかります。
せっかく3Dグラフを簡単に描けるクラスを作ったので、このクラスを使って綺麗なグラフを描こうと思い、数式をPythonでソースコード化したものを探し回ってみました。
しかし、あちこちのサイトに少しづつ点在している状態だったので、見つけた内容をほぼそのままメソッド化し、データ自動生成クラスを作成してみました。
私自身あまり数学に得意ではないので、中身は良く分かっていませんが、とりあえずインスタンスを生成してメソッド( sample1~sample6 )を呼べば、それらしいデータが生成されます。
使い方
Create3DData というクラス名で作成しています。
クラス内では numpy を使っているので、あらかじめインポートしておいてください。
使い方は、インスタンスを生成し、sample1 ~ sample6 のいずれかのメソッドを呼ぶだけです。
戻り値として x,y,z が返って来ますので、Plot3D の set_data メソッドにセットします。
#データの自動生成
data = Create3DData()
x,y,z = data.sample6()
#3Dグラフの描画
plt3d = Plot3D("3Dグラフ")
plt3d.set_data(x,y,z)
plt3d.scatter(size=1,color='red')
plt3d.show()

ソースコード
ソースコードは以下の通りです。
コメントは記述していませんが、コンストラクタにデータ個数を指定することができます。
データ個数が多いと3Dグラフの表示が重くなるので、初期値は50にしています。
class Create3DData:
def __init__(self,size=50):
self.size = size
def sample1(self,a=2,b=9):
u=np.linspace(0,2*np.pi,self.size)
v=np.linspace(0,2*np.pi,self.size)
u,v=np.meshgrid(u,v)
x = (b + a*np.cos(u)) * np.cos(v)
y = (b + a*np.cos(u)) * np.sin(v)
z = a * np.sin(u)
return x,y,z
def sample2(self):
x = np.linspace(-3*np.pi,3*np.pi,self.size)
y = np.linspace(-3*np.pi,3*np.pi,self.size)
# X、Yデータの作成
x, y = np.meshgrid(x, y)
# 高度データの作成
z = np.cos(x/np.pi) * np.sin(y/np.pi)
return x,y,z
def sample3(self):
x = np.linspace(-3*np.pi,3*np.pi,self.size)
y = np.linspace(-3*np.pi,3*np.pi,self.size)
# X、Yデータの作成
x, y = np.meshgrid(x, y)
#z = 50 * np.cos(np.sqrt(x**2+y**2)/10)
z = 50 * np.cos(x*y/2000)
return x,y,z
def sample4(self):
x = np.arange(-3, 3, 6/self.size) # x点として[-2, 2]まで0.05刻みでサンプル
y = np.arange(-3, 3, 6/self.size) # y点として[-2, 2]まで0.05刻みでサンプル
x, y = np.meshgrid(x, y) # 上述のサンプリング点(x,y)を使ったメッシュ生成
z = np.exp(-(x**2 + y**2)) #exp(-(x^2+y^2)) を計算してzz座標へ格納する。
return x,y,z
def sample5(self):
u = np.linspace(0, 2 * np.pi, self.size)
v = np.linspace(0, np.pi, self.size)
r = np.hstack((np.linspace(0, 10, 50),np.linspace(10, 0, 50)))
x = np.outer(np.cos(u), np.sin(v))
y = np.outer(np.sin(u), np.sin(v))
z = np.outer(np.ones(np.size(u)), np.cos(v))
return x,y,z
def sample6(self):
x = np.arange(-5, 5, 10/self.size)
y = np.arange(-5, 5, 10/self.size)
x, y = np.meshgrid(x, y)
r = np.sqrt(x**2 + y**2)
z = np.sin(r)
return x,y,z
3Dグラフのサンプル
6つのサンプルを順番に表示するサンプルプログラムを作りましたので、紹介しておきます。
dt = Create3DData(50)
datas = [dt.sample1(),dt.sample2(),dt.sample3(),dt.sample4(),dt.sample5(),dt.sample6()]
for data in datas:
plt3d = Plot3D("3Dグラフ")
plt3d.set_data(data[0],data[1],data[2])
#plt3d.surface()
#plt3d.wireframe()
#plt3d.scatter(size=1)
plt3d.cmap_scatter(size=1)
plt3d.contour()
plt3d.show()
上記のプログラムを実行すると、次のグラフが次々に表示されます。
カラーバーを使いたくない場合は scatter メソッドの方を使ってください。

以下は plt3d.wireframe メソッドを使った場合のワイヤーフレーム描画です。

以下は plt3d. surface メソッドを使った場合の曲面描画です。

以下は plt3d. surface と plt3d.scatter メソッドを使った場合の描画です。

まとめ
今回は 3Dグラフを簡単に描くためのクラス(Plot3D)と、3Dグラフのデータを自動生成するクラス(Create3DData)について紹介しました。
plot3DについてはX,Y,Zのデータを渡すだけで、簡単に3Dグラフを表示することが可能です。
必要最小限の機能しか実装していませんので、必要に応じてカスタマイズしてお使い下さい。
今回の記事が皆様のお役に立てれば幸いです。
コメント