'''
Functions to plot fields
'''
from .field import *
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
def _grid_shape(nfields):
    """Return ``(nrows, ncols)`` for a subplot grid holding ``nfields`` panels.

    Up to four panels use fixed layouts (1x1, 1x2, 1x3, 2x2). Beyond that the
    grid is kept as close to square as possible: ``ncols`` is
    ceil(sqrt(nfields)) and ``nrows`` is only as large as needed to hold every
    panel, so no row is left completely empty (e.g. 5 fields -> 2x3, not 3x3).
    Assumes ``nfields >= 1``.
    """
    import math

    fixed = {1: (1, 1), 2: (1, 2), 3: (1, 3), 4: (2, 2)}
    if nfields in fixed:
        return fixed[nfields]
    # integer arithmetic avoids float rounding:
    # isqrt(n - 1) + 1 == ceil(sqrt(n)) for every n >= 1
    ncols = math.isqrt(nfields - 1) + 1
    nrows = -(-nfields // ncols)  # ceiling division
    print(f"Automatically set {nrows = }, {ncols = }")
    return nrows, ncols


def _component_data(field, ifield, use_imag):
    """Return the array to plot for ``field``.

    For complex fields this is the real part (``use_imag`` false) or the
    imaginary part (``use_imag`` true), with a note printed about the part
    that is dropped. For real fields the data is returned unchanged, and a
    warning is printed if the imaginary part was requested.
    """
    if field.is_complex():
        if use_imag:
            print(f"Note: only plotting imaginary part of field {ifield}")
            return field.data.imag
        print(f"Note: not plotting imaginary part of field {ifield}")
        return field.data.real
    if use_imag:
        print("Warning: imag set to 'True' but fields are purely real")
    return field.data


def plot(fields, dpi=100, show=True, filename=None, imag=False):
    """Plot fields using matplotlib, one subplot per field.

    1-D fields are drawn as line plots, 2-D fields as ``pcolormesh`` colour
    maps (with a colorbar) and 3-D fields as 3-D scatter plots coloured by
    value (with a colorbar).

    Args:
        fields: a Field object, or a list of Field objects
        dpi: dpi (resolution) of the figure
        show: whether or not to show the plot (``plt.show()``)
        filename: if given, the figure is saved to this file (before showing)
        imag: which component of complex fields to plot. ``False`` or
            ``'real'`` plots the real part (default); ``True`` or ``'imag'``
            plots the imaginary part. Real fields are always plotted as-is
            (a warning is printed if the imaginary part was requested).
    Returns:
        None
    Raises:
        ValueError: if ``fields`` is an empty collection.
    """
    try:
        nfields = len(fields)
    except TypeError:
        # a lone Field (no len()) is wrapped into a one-element list
        fields = [fields]
        nfields = 1
    if nfields == 0:
        raise ValueError("plot() requires at least one field")

    # The docstring speaks of 'real'/'imag' components, so accept those
    # strings as well as booleans. Only 'real' selects the real part among
    # strings; any other string keeps its previous meaning (imaginary part).
    if isinstance(imag, str):
        use_imag = imag.strip().lower() != 'real'
    else:
        use_imag = bool(imag)

    nrows, ncols = _grid_shape(nfields)
    fig = plt.figure(figsize=(ncols * 3.33, nrows * 3.33), dpi=dpi)
    try:
        for ifield, field in enumerate(fields):
            # panels fill the grid row by row; subplot indices are 1-based.
            # The axes type is chosen per field, so mixed-dimension lists work.
            if field.dim <= 2:
                ax = fig.add_subplot(nrows, ncols, ifield + 1)
            else:
                # importing Axes3D registers the '3d' projection on
                # matplotlib < 3.2, so it must precede add_subplot
                from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
                ax = fig.add_subplot(nrows, ncols, ifield + 1, projection='3d')

            data = _component_data(field, ifield, use_imag)

            if nfields != 1:
                ax.set_title(f'Field {ifield}')

            if field.dim == 1:
                ax.set_xlabel('x')
                ax.set_ylabel('field value')
                ax.plot(field.coords, data)
            elif field.dim == 2:
                ax.set_xlabel('x')
                ax.set_ylabel('y')
                ax.axis('equal')
                # coords is indexed as (..., component): [:, :, 0] is x,
                # [:, :, 1] is y
                pc = ax.pcolormesh(field.coords[:, :, 0], field.coords[:, :, 1],
                                   data, shading='auto', cmap='coolwarm')
                cb = fig.colorbar(pc, ax=ax)
                cb.set_label('field value')
            elif field.dim == 3:
                ax.set_xlabel('x')
                ax.set_ylabel('y')
                ax.set_zlabel('z', labelpad=0)
                # ax.axis('equal') is not supported on 3-D axes
                scatter = ax.scatter(field.coords[:, :, :, 0],
                                     field.coords[:, :, :, 1],
                                     field.coords[:, :, :, 2],
                                     c=data, cmap='coolwarm')
                fig.colorbar(scatter, ax=ax, pad=0.2, shrink=0.5)

        fig.tight_layout()
        if filename:
            fig.savefig(filename)
        if show:
            plt.show()
    finally:
        # always release this figure, even if plotting failed part-way
        plt.close(fig)