【コピペで使える】3D plot を簡単に描くクラスを作ってみました。(by Python)

当ページのリンクには広告が含まれています。

前回の記事でmatplotlib による3Dグラフの描き方を紹介しました。

今回は、もっと簡単に3Dグラフが描けるようクラス化してみました。

データを与えるだけでグリグリ回る3Dグラフが出来上がります。

ついで、3Dグラフを描画するためのデモデータのサンプルもクラス化してみました。

興味のある方は是非、ご一読下さい。

目次

最初の準備

今回紹介するクラスを使うには matplotlib.pyplot と mpl_toolkits.mplot3d をインポートする必要があるので、下記2行を冒頭に追記して下さい。

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

Anaconda 経由で Python をインストールしていると既に含まれていますが、python 本家のサイトから Pythonをインストールしている場合、pip コマンドによるインストールが必要となります。

pip install matplotlib

Plot3Dクラスの使い方

クラス名は Plot3D という名前にしました。

使い方は次の通りです。

ちなみに、「XYZラベルの設定」は省略可能で、省略した場合は”X”、”Y”、”Z” が軸のラベル名として設定されます。

このクラスを使って3D空間に散布図と折れ線を引くには、次の様になります。

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

#描画するデータの定義

x=[1,2,3,4,5]
y=[5,4,3,2,1]
z=[1,2,3,2,1]

#3Dグラフの表示
plt3d = Plot3D("3Dグラフ")
plt3d.set_data(x,y,z)
plt3d.scatter()

plt3d.plot()
plt3d.show()

下記が表示されるグラフです。

Plot3Dのリファレンス

ソースコードのコメントを使ってリファレンスを作ってみました。

機能メソッドとパラメータ
コンストラクタ __init__(self,title,size=15,color="black")

title:str       グラフタイトル
size:int        グラフタイトルのフォントサイズ
color:str      グラフタイトルの色
軸ラベルのサイズと色を設定set_label(self,xlabel,ylabel,zlabel,size=12,color="black") 

xlabel:int     X軸のラベル
ylabel:int     Y軸のラベル
zlabel:int     Z軸のラベル
size : int       ラベルのフォントサイゼ宇
color:str      ラベルの色
描画データの登録set_data(self,x,y,z,values = None)

x:int        Xの値
y:int      Yの値
z:int       Zの値
value:int   XYZ座標のデータが持つ値(初期値はNone)
線の描画plot(self,width=1,color="black")

size:int     線の太さ
color:str    線の色
散布図の描画scatter(self,size=1,color="black")

size :int     マーカーサイズ
color : str    マーカー色
カラーバー付き散布図の描画cmap_scatter(self,size=1,cmap="hsv")

size :int      マーカーサイズ
cmap : str    カラーマップ名
曲面の描画surface(self,cmap="hsv",alpha=0.8)

cmap : str    カラーマップ名
alpha:float    面の透明度
ワイヤーフレームの描画wireframe(self,color="blue")

color : str    ワイヤーの色
底面の等高線描画contour(self,color="black",offset=None)

color : str            等高線の色
offset : int,double    等高線のZ位置(0からのオフセット)
グラフの描画show()

ソースコード

クラスのソースコードは以下の通りです。

基本的な内容をクラス化していますので、必要に応じて修正、追記の上、お使いください。

class Plot3D:
    
    def __init__(self,title,size=15,color="black"):
        '''
        コンストラクタ
        
        Parameters
        ----------
        title:str      グラフタイトル
        size:int       グラフタイトルのフォントサイズ
        color:str      グラフタイトルの色
        '''
        
        #タイトルで漢字が使えるようフォントを設定
        plt.rcParams['font.family'] = 'Meiryo'
        
        #描画エリアの作成
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(projection='3d')
        self.ax.set_title(title,size=size,color="black")
        self.set_label('X','Y','Z')
        self.x = []
        self.y = []
        self.z = []
        self.value = []
    
    def set_label(self,xlabel,ylabel,zlabel,size=12,color="black") :
        '''
        軸ラベルのサイズと色を設定
        
        Parameters
        ----------
        xlabel:int     X軸のラベル
        ylabel:int     Y軸のラベル
        zlabel:int     Z軸のラベル
        size : int      ラベルのフォントサイゼ宇
        color:str      ラベルの色
        '''
        
        self.ax.set_xlabel(xlabel,size=size,color=color)
        self.ax.set_ylabel(ylabel,size=size,color=color)
        self.ax.set_zlabel(zlabel,size=size,color=color)
    
    def set_data(self,x,y,z,values = None):
        '''
        描画データの登録
        
        Parameters
        ----------
        x:int     Xの値
        y:int     Yの値
        z:int     Zの値
        value:int  XYZ座標のデータが持つ値
        '''
        self.x = x
        self.y = y
        self.z = z
        self.values = values
    
    def plot(self,width=1,color="black"):
        '''
        線の描画
        
        Parameters
        ----------
        size:int     線の太さ
        color:str    線の色
        '''
        self.ax.plot(self.x,self.y,self.z,color=color,linewidth=width)
    
    def scatter(self,size=1,color="black"):
        '''
        散布図の描画
        
        Parameters
        ----------
        size :int     マーカーサイズ
        color : str    マーカー色
        '''
        self.ax.scatter(self.x,self.y,self.z,s=size,color=color)
    
    def cmap_scatter(self,size=1,cmap="hsv"):
        '''
        カラーバー付き散布図の描画
        
        Parameters
        ----------
        size :int     マーカーサイズ
        cmap : str    カラーマップ名
        '''
        val = self.z if self.values is None else self.values
        mappable = self.ax.scatter(self.x,self.y,self.z,c=val,s=size,cmap=cmap)
        self.fig.colorbar(mappable, ax=self.ax)   
    
    def surface(self,cmap="hsv",alpha=0.8):
        '''
         曲面の描画
        
        Parameters
        ----------
        cmap : str    カラーマップ名
        alpha:float   面の透明度
        '''
        self.ax.plot_surface(self.x,self.y,self.z,cmap=cmap,alpha=alpha)
    
    def wireframe(self,color="blue"):
        '''
         ワイヤーフレームの描画
        
        Parameters
        ----------
        color : str    ワイヤーの色
        '''
        self.ax.plot_wireframe(self.x,self.y,self.z,color=color)
    
    def contour(self,color="black",offset=None):
        '''
        底面の等高線描画
        
        Parameters
        ----------
        color : str           等高線の色
        offset : int,double    等高線のZ位置(0からのオフセット)
        '''
        
        offset = np.min(self.z) if offset is None else offset
        self.ax.contour(self.x,self.y,self.z,colors="black",offset=offset)
    
    def show(self):
        '''
        グラフの描画
        '''
        plt.show()

3Dグラフ用サンプルデータ自動生成クラス

3Dグラフでググってみると、三角関数を使った綺麗なグラフの表示例が見つかります。

せっかく3Dグラフを簡単に描けるクラスを作ったので、このクラスを使って綺麗なグラフを描こうと思い、数式をPythonでソースコード化したものを探し回ってみました。

しかし、あちこちのサイトに少しづつ点在している状態だったので、見つけた内容をほぼそのままメソッド化し、データ自動生成クラスを作成してみました。

私自身あまり数学に得意ではないので、中身は良く分かっていませんが、とりあえずインスタンスを生成してメソッド( sample1~sample6 )を呼べば、それらしいデータが生成されます。

使い方

Create3DData というクラス名で作成しています。

クラス内では numpy を使っているので、あらかじめインポートしておいてください。

使い方は、インスタンスを生成し、sample1 ~ sample6 のいずれかのメソッドを呼ぶだけです。

戻り値として x,y,z が返って来ますので、Plot3D の set_data メソッドにセットします。

#データの自動生成
data = Create3DData()
x,y,z = data.sample6()

#3Dグラフの描画
plt3d = Plot3D("3Dグラフ")
plt3d.set_data(x,y,z)
plt3d.scatter(size=1,color='red')
plt3d.show()

ソースコード

ソースコードは以下の通りです。

コメントは記述していませんが、コンストラクタにデータ個数を指定することができます。

データ個数が多いと3Dグラフの表示が重くなるので、初期値は50にしています。

class Create3DData:
    
    def __init__(self,size=50):
        self.size = size
    
    def sample1(self,a=2,b=9):
        u=np.linspace(0,2*np.pi,self.size)
        v=np.linspace(0,2*np.pi,self.size)
        u,v=np.meshgrid(u,v)
        x = (b + a*np.cos(u)) * np.cos(v)
        y = (b + a*np.cos(u)) * np.sin(v)
        z = a * np.sin(u) 
        return x,y,z    
    
    def sample2(self):
        x = np.linspace(-3*np.pi,3*np.pi,self.size)
        y = np.linspace(-3*np.pi,3*np.pi,self.size)
        # X、Yデータの作成
        x, y = np.meshgrid(x, y)
        # 高度データの作成
        z = np.cos(x/np.pi) * np.sin(y/np.pi)
        return x,y,z
    
    def sample3(self):
        x = np.linspace(-3*np.pi,3*np.pi,self.size)
        y = np.linspace(-3*np.pi,3*np.pi,self.size)
        # X、Yデータの作成
        x, y = np.meshgrid(x, y)
        #z = 50 * np.cos(np.sqrt(x**2+y**2)/10)
        z = 50 * np.cos(x*y/2000)
        return x,y,z
    
    def sample4(self):
        x = np.arange(-3, 3, 6/self.size) # x点として[-2, 2]まで0.05刻みでサンプル
        y = np.arange(-3, 3, 6/self.size)  # y点として[-2, 2]まで0.05刻みでサンプル
        x, y = np.meshgrid(x, y)  # 上述のサンプリング点(x,y)を使ったメッシュ生成
        z = np.exp(-(x**2 + y**2))  #exp(-(x^2+y^2))  を計算してzz座標へ格納する。
        return x,y,z
    
    def sample5(self):
        u = np.linspace(0, 2 * np.pi, self.size)
        v = np.linspace(0, np.pi, self.size)
        r = np.hstack((np.linspace(0, 10, 50),np.linspace(10, 0, 50)))
        x = np.outer(np.cos(u), np.sin(v)) 
        y = np.outer(np.sin(u), np.sin(v))
        z = np.outer(np.ones(np.size(u)), np.cos(v))
        return x,y,z
    
    def sample6(self):
        x = np.arange(-5, 5, 10/self.size)
        y = np.arange(-5, 5, 10/self.size)
        x, y = np.meshgrid(x, y)
        r = np.sqrt(x**2 + y**2)
        z = np.sin(r)
        return x,y,z

3Dグラフのサンプル

6つのサンプルを順番に表示するサンプルプログラムを作りましたので、紹介しておきます。

dt = Create3DData(50)

datas = [dt.sample1(),dt.sample2(),dt.sample3(),dt.sample4(),dt.sample5(),dt.sample6()]

for data in datas:
    plt3d = Plot3D("3Dグラフ")
    plt3d.set_data(data[0],data[1],data[2])
    #plt3d.surface()
    #plt3d.wireframe()
    #plt3d.scatter(size=1)
    plt3d.cmap_scatter(size=1)
    plt3d.contour()
    plt3d.show()

上記のプログラムを実行すると、次のグラフが次々に表示されます。

カラーバーを使いたくない場合は scatter メソッドの方を使ってください。

以下は plt3d.wireframe メソッドを使った場合のワイヤーフレーム描画です。

以下は plt3d. surface メソッドを使った場合の曲面描画です。

以下は plt3d. surface と plt3d.scatter メソッドを使った場合の描画です。

まとめ

今回は 3Dグラフを簡単に描くためのクラス(Plot3D)と、3Dグラフのデータを自動生成するクラス(Create3DData)について紹介しました。

plot3DについてはX,Y,Zのデータを渡すだけで、簡単に3Dグラフを表示することが可能です。

必要最小限の機能しか実装していませんので、必要に応じてカスタマイズしてお使い下さい。

今回の記事が皆様のお役に立てれば幸いです。

よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!

コメント

コメントする

目次