import os
import sys
from io import StringIO
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.transforms import Bbox
# plt.rcParams["axes.axisbelow"] = False
from copy import copy
try:
    from ..units import nice_array, nice_scale_prefix, set_nice_array
except ImportError:
    # Best effort: if the package-relative units helpers cannot be imported the
    # module still loads, but every plotting function that calls nice_array /
    # nice_scale_prefix will then fail with NameError.
    pass

try:
    from fastkde import fastKDE

    fastKDE_installed = True
except ImportError:
    print("fastKDE missing - plotScreenImage will use SciPy")
    fastKDE_installed = False

try:
    from scipy import stats

    SciPy_installed = True
except ImportError:
    SciPy_installed = False

# copy() so that set_under() modifies a private colormap object rather than
# the one returned by get_cmap. Values below the colour-scale minimum
# (e.g. vmin=1e-20 in marginal_plot's hexbin) are drawn white.
CMAP0 = copy(plt.get_cmap("viridis"))
CMAP0.set_under("white")
CMAP1 = copy(plt.get_cmap("plasma"))
[docs]
def density_plot(
    particle_group,
    key="x",
    bins=None,
    filename=None,
    **kwargs,
):
    """
    1D density plot. Also see: marginal_plot

    Histograms ``getattr(particle_group, key)`` weighted by
    ``abs(particle_group.charge)`` (unit weights when ``key == "charge"``)
    and draws the result, divided by the bin width, as a grey bar chart.

    Parameters
    ----------
    particle_group
        Object supporting ``len()`` with a ``charge`` attribute and the
        attribute named by *key*.
    key : str
        Attribute to histogram.
    bins : int, optional
        Number of histogram bins. Defaults to ``len(particle_group) // 100``,
        but never fewer than 1.
    filename : str, optional
        If given, the figure is saved to this path.
    **kwargs
        Passed to ``plt.subplots``.

    Example:
        density_plot(P, 'x', bins=100)
    """
    if not bins:
        # ~100 particles per bin; small groups would otherwise get 0 bins,
        # which np.histogram rejects.
        bins = max(1, len(particle_group) // 100)

    values = getattr(particle_group, key)
    # Scale to nice units and get the factor, unit prefix
    x, f1, p1 = nice_array(values)
    if key != "charge":
        w = abs(particle_group.charge)
    else:
        w = np.ones(len(values))

    # Unit symbol of the key if the array carries one (marginal_plot relies on
    # `.units` the same way); plain arrays fall back to no unit.
    u1 = str(getattr(values, "units", ""))
    ux = p1 + u1
    labelx = f"{key} ({ux})"

    fig, ax = plt.subplots(**kwargs)
    hist, bin_edges = np.histogram(x, bins=bins, weights=w)
    hist_width = np.diff(bin_edges)
    hist_x = bin_edges[:-1] + hist_width / 2
    hist_y, hist_f, hist_prefix = nice_array(hist / hist_width)
    ax.bar(hist_x, hist_y, hist_width, color="grey")

    # Special label for C/s = A
    if u1 == "s":
        _, hist_prefix = nice_scale_prefix(hist_f / f1)
        ax.set_ylabel(f"{hist_prefix}A")
    else:
        ax.set_ylabel(f"{hist_prefix}C/{ux}")
    ax.set_xlabel(labelx)

    if isinstance(filename, str):
        fig.savefig(filename)
[docs]
def slice_plot(
    particle_group,
    xkey="t",
    ykey="slice_current",
    xlim=None,
    nice=True,
    include_legend=True,
    subtract_mean=True,
    bins=None,
    filename=None,
    **kwargs,
):
    """
    Slice plot: one or more slice quantities against a slice coordinate.
    Also see: marginal_plot

    The group is re-sliced into ``bins`` slices (this sets
    ``particle_group.slice.slices``), then each ``ykey`` attribute of
    ``particle_group.slice`` is drawn against
    ``particle_group.slice.slice_<xkey>``.

    Parameters
    ----------
    particle_group
        Object supporting ``len()`` with a ``.slice`` attribute exposing
        ``slice_<xkey>`` and the requested ``ykey`` arrays (each with ``.units``).
    xkey : str
        Suffix of the x coordinate (``"t"`` -> ``slice_t``). The x label
        always uses seconds as its unit.
    ykey : str or list/tuple of str
        Slice attribute name(s) to plot; they must share the same units.
    xlim : (float, float), optional
        x range in the (possibly mean-subtracted) coordinate; slices outside
        it are dropped.
    nice : bool
        Rescale both axes to "nice" units with an SI prefix.
    include_legend : bool
        Draw a legend; only applied when more than one ykey is plotted.
    subtract_mean : bool
        Subtract the mean of the x coordinate.
    bins : int, optional
        Number of slices. Defaults to ``len(particle_group) // 100``, but
        never fewer than 1.
    filename : str, optional
        If given, the figure is saved to this path.
    **kwargs
        Passed to ``plt.subplots``.

    Raises
    ------
    AssertionError
        If the requested ykeys have different units.

    Example:
        slice_plot(P, 'slice_current', bins=100)
    """
    P = particle_group
    fig, ax = plt.subplots(**kwargs)

    if not bins:
        # ~100 particles per slice; small groups would otherwise get 0 slices.
        bins = max(1, len(P) // 100)
    P.slice.slices = bins

    X = getattr(P.slice, "slice_" + xkey)
    if subtract_mean:
        X = X - np.mean(X)

    if isinstance(ykey, str) or not isinstance(ykey, (list, tuple)):
        ykey = [ykey]

    # Only keep the data inside the requested x range
    if xlim:
        good = np.logical_and(X >= xlim[0], X <= xlim[1])
        X = X[good]
    else:
        xlim = X.min(), X.max()
        good = slice(None)  # everything

    # X axis scaling
    units_x = "s"
    if nice:
        X, factor_x, prefix_x = nice_array(X)
        units_x = prefix_x + units_x
    else:
        factor_x = 1
    ax.set_xlim(xlim[0] / factor_x, xlim[1] / factor_x)
    ax.set_xlabel(f"{xkey} ({units_x})")

    if ykey:
        # All curves share one y axis, so their units must agree.
        ulist = [getattr(P.slice, key).units for key in ykey]
        for u2 in ulist[1:]:
            assert ulist[0] == u2, f"Incompatible units: {ulist[0]} and {u2}"
        unit = str(ulist[0])

        data = [np.array(getattr(P.slice, key)[good]) for key in ykey]
        if nice:
            factor, prefix = nice_scale_prefix(np.ptp(data))
            unit = prefix + unit
        else:
            factor = 1

        for ii, (key, dat) in enumerate(zip(ykey, data)):
            ax.plot(
                X,
                dat / factor,
                label=f"{key} ({unit})",
                color=f"C{ii}",
                linestyle="solid",
            )
        ax.set_ylabel(", ".join(ykey) + f" ({unit})")

    # A legend only helps when several curves are drawn.
    if include_legend and len(ykey) > 1:
        lines, labels = ax.get_legend_handles_labels()
        ax.legend(lines, labels, loc="best")

    if isinstance(filename, str):
        fig.savefig(filename)
[docs]
def _marginal_density(values, bins, weights, charge):
    """Histogram *values* and convert the counts to a per-unit-length density.

    Returns ``(centres, heights, widths, factor, prefix, charged)``:
    bin centres and widths, the nice-scaled density with the factor and
    prefix returned by ``nice_array``, and whether the total charge is
    non-zero. When it is, the density is scaled by
    ``-sum(charge) / len(charge)`` (charge per unit); otherwise it is in
    counts per unit.
    """
    hist, bin_edges = np.histogram(values, bins=bins, weights=weights)
    widths = np.diff(bin_edges)
    centres = bin_edges[:-1] + widths / 2
    total_charge = np.sum(charge).val
    charged = abs(total_charge) > 0
    if charged:
        density = -total_charge * hist / widths / len(charge)
    else:
        density = hist / widths
    heights, factor, prefix = nice_array(density)
    return centres, heights, widths, factor, prefix, charged


def marginal_plot(
    particle_group,
    key1="t",
    key2="p",
    bins=None,
    units=("", ""),
    scale=(1, 1),
    subtract_mean=(False, False),
    cmap=None,
    limits=None,
    filename=None,
    **kwargs,
):
    """
    Density plot and projections

    Draws a hexbin density of ``key2`` against ``key1`` with histogram
    projections on the top (``key1``) and right (``key2``) axes.

    Parameters
    ----------
    particle_group
        Object supporting ``len()`` with a ``charge`` attribute and attributes
        *key1*/*key2* (arrays with ``.units``).
    key1, key2 : str
        Horizontal and vertical coordinates.
    bins : int, optional
        Hexbin grid size and number of histogram bins. Defaults to
        ``int(sqrt(n / 2))``, but never fewer than 1.
    units
        Accepted for backward compatibility; not used.
    scale : number or (number, number)
        Multiplier applied before choosing the nice-unit prefix.
    subtract_mean : bool or (bool, bool)
        Subtract each coordinate's mean.
    cmap : optional
        Colormap for the joint plot (default CMAP0).
    limits : [xmin, xmax, ymin, ymax], optional
        Axis limits for the joint plot and matching marginals.
    filename : str, optional
        If given, the figure is saved to this path.
    **kwargs
        Passed to ``plt.figure``.

    Example:
        marginal_plot(P, 't', 'energy', bins=200)
    """
    if not bins:
        # A single particle would otherwise give bins == 0.
        bins = max(1, int(np.sqrt(len(particle_group) / 2)))
    cmap = CMAP0 if cmap is None else cmap
    if not isinstance(subtract_mean, (list, tuple)):
        subtract_mean = [subtract_mean, subtract_mean]
    if not isinstance(scale, (list, tuple)):
        scale = [scale, scale]

    raw_x = getattr(particle_group, key1)
    raw_y = getattr(particle_group, key2)
    # Scale to nice units and get the factor, unit prefix
    x, f1, p1 = nice_array(scale[0] * (raw_x - subtract_mean[0] * np.mean(raw_x)))
    y, f2, p2 = nice_array(scale[1] * (raw_y - subtract_mean[1] * np.mean(raw_y)))
    x = x / scale[0]
    y = y / scale[1]
    w = np.full(len(x), 1)
    charge = particle_group.charge

    u1, u2 = raw_x.units, raw_y.units
    ux = p1 + u1
    uy = p2 + u2
    labelx = f"{key1} ({ux})"
    labely = f"{key2} ({uy})"

    fig = plt.figure(**kwargs)
    gs = GridSpec(4, 4)
    ax_joint = fig.add_subplot(gs[1:4, 0:3])
    ax_marg_x = fig.add_subplot(gs[0, 0:3])
    ax_marg_y = fig.add_subplot(gs[1:4, 3])

    # Joint density: summing unit weights gives the particle count per cell.
    ax_joint.hexbin(
        x, y, C=w, reduce_C_function=np.sum, gridsize=bins, cmap=cmap, vmin=1e-20
    )
    if limits is not None:
        ax_joint.axis(limits)

    # Top histogram: density per unit of key1, so labels use ux.
    cx, hx, wx, fx, px, charged = _marginal_density(x, bins, w, charge)
    ax_marg_x.bar(cx, hx, wx, color="gray")
    if u1 == "s" and charged:
        # Special label for C/s = A
        _, px = nice_scale_prefix(fx / f1)
        ax_marg_x.set_ylabel(f"{px}A")
    elif charged:
        ax_marg_x.set_ylabel(f"{px}C/{ux}")
    else:
        ax_marg_x.set_ylabel(f"{px}Counts/{ux}")
    if limits is not None:
        ax_marg_x.set_xlim(limits[0:2])

    # Side histogram: density per unit of key2.
    cy, hy, wy, _, py, charged = _marginal_density(y, bins, w, charge)
    ax_marg_y.barh(cy, hy, wy, color="gray")
    if charged:
        ax_marg_y.set_xlabel(f"{py}C/{uy}")
    else:
        ax_marg_y.set_xlabel(f"{py}Counts/{uy}")
    if limits is not None:
        ax_marg_y.set_ylim(limits[2:])

    # Turn off tick labels on marginals
    plt.setp(ax_marg_x.get_xticklabels(), visible=False)
    plt.setp(ax_marg_y.get_yticklabels(), visible=False)

    # Set labels on joint
    ax_joint.set_xlabel(labelx)
    ax_joint.set_ylabel(labely)

    if isinstance(filename, str):
        fig.savefig(filename)
[docs]
def plot(self, keys=None, bins=None, type="density", **kwargs):
    """
    Dispatch to the appropriate plot for *keys*.

    A single key (a string or a one-element list/tuple) produces
    ``slice_plot`` when ``type == "slice"`` or the key contains ``"slice_"``,
    and otherwise ``density_plot`` when ``type == "density"``. With any other
    ``type`` a single key draws nothing and returns None. Two keys produce
    ``marginal_plot(self, key1=keys[0], key2=keys[1], ...)``.

    Raises
    ------
    TypeError
        If *keys* is None.
    """
    if keys is None:
        raise TypeError("plot() requires keys: a key name or a pair of key names")
    # Normalise a bare string to a one-element list so both spellings share
    # a single code path (a bare string previously left ykey unassigned).
    if isinstance(keys, str):
        keys = [keys]
    if isinstance(keys, (list, tuple)) and len(keys) == 1:
        ykey = keys[0]
        if type == "slice" or "slice_" in ykey:
            return slice_plot(self, ykey=ykey, bins=bins, **kwargs)
        if type == "density":
            return density_plot(self, key=ykey, bins=bins, **kwargs)
        return None
    xkey, ykey = keys
    return marginal_plot(self, key1=xkey, key2=ykey, bins=bins, **kwargs)
[docs]
def _screen_image_pdf(x, y, use_scipy, kde_kwargs):
    """
    Estimate the joint density of *x* and *y* on a regular grid.

    Returns ``(pdf, xaxis, yaxis)`` with 1D ``xaxis``/``yaxis`` and ``pdf``
    laid out as ``pdf[iy, ix]`` -- the layout the fastKDE path is used with
    (``pcolormesh(xaxis, yaxis, pdf)``, ``pdf.sum(axis=0)`` = x marginal) --
    for both backends. fastKDE is used when installed and *use_scipy* is
    false; otherwise ``scipy.stats.gaussian_kde`` is evaluated on a
    100 x 100 grid spanning the data. *kde_kwargs* go to ``fastKDE.pdf`` only.
    """
    if fastKDE_installed and not use_scipy:
        pdf, (xaxis, yaxis) = fastKDE.pdf(x, y, use_xarray=False, **kde_kwargs)
        return pdf, xaxis, yaxis
    if SciPy_installed:
        xaxis = np.linspace(x.min(), x.max(), 100)
        yaxis = np.linspace(y.min(), y.max(), 100)
        # meshgrid's default "xy" indexing yields (len(yaxis), len(xaxis))
        # arrays, i.e. the [iy, ix] layout used for the fastKDE output.
        grid_x, grid_y = np.meshgrid(xaxis, yaxis)
        kernel = stats.gaussian_kde(np.vstack([x, y]))
        positions = np.vstack([grid_x.ravel(), grid_y.ravel()])
        pdf = kernel(positions).reshape(grid_x.shape)
        return pdf, xaxis, yaxis
    raise Exception("fastKDE or SciPy required")


def plotScreenImage(
    beam,
    keys=("x", "y"),
    scale=(1, 1),
    iscale=1,
    colormap=plt.cm.jet,
    size=None,
    grid=False,
    marginals=False,
    limits=None,
    screen=False,
    use_scipy=False,
    subtract_mean=(False, False),
    title="",
    filename=None,
    fig=None,
    ax=None,  # external Axes
    labelsize=None,  # axis label font size
    **kwargs,
):
    """
    Draw a kernel-density "screen image" of two beam coordinates.

    Parameters
    ----------
    beam
        Object with attributes ``keys[0]``/``keys[1]`` (arrays with ``.units``).
    keys : (str, str)
        Horizontal and vertical coordinates.
    scale : number or (number, number)
        Multiplier applied to each coordinate before nice-unit scaling.
    iscale : float
        Peak value of the image; the density is divided by its maximum.
    colormap
        Matplotlib colormap for the image and the marginal bars.
    size : number or (number, number), optional
        When given, it is divided by the nice_array factor of each axis and
        the view spans +/- (size + 0.5) around the centre. The caller's
        object is never modified.
    grid : bool
        Draw white major/minor grid lines clipped to a circle.
    marginals : bool
        Add projected-density bar charts above and to the right of the image.
        Not allowed together with *ax*.
    limits : array-like, optional
        ``[[xmin, xmax], [ymin, ymax]]``, ``[min, max]`` (both axes) or
        ``[xmin, xmax, ymin, ymax]``. Takes precedence over *size*/*screen*.
    screen : bool
        Draw a radius-15 screen disc on a black background.
    use_scipy : bool
        Use the SciPy KDE even if fastKDE is installed.
    subtract_mean : bool or (bool, bool)
        Subtract each coordinate's mean.
    title : str
        Currently unused.
    filename : str, optional
        If given, the figure containing *ax* is saved to this path.
    fig
        Ignored; the figure is created here or taken from *ax*.
    ax : matplotlib Axes, optional
        Draw into this axes instead of creating a new figure.
    labelsize : float, optional
        Font size for the axis labels and major tick labels.
    **kwargs
        Passed to ``fastKDE.pdf`` (ignored by the SciPy backend).

    Returns
    -------
    (fig, ax)

    Raises
    ------
    ValueError
        If *marginals* is requested together with an external *ax*.
    Exception
        If neither fastKDE nor SciPy is available.
    """
    # --- Process inputs ---
    if ax is not None and marginals:
        raise ValueError("marginals=True cannot be used when an external ax= is provided.")
    key1, key2 = keys
    if not isinstance(subtract_mean, (list, tuple)):
        subtract_mean = [subtract_mean, subtract_mean]
    if not isinstance(scale, (list, tuple)):
        scale = [scale, scale]
    # Work on a private list: size is rescaled below and must not leak back
    # into the caller's list (or fail when a tuple is passed).
    size = list(size) if isinstance(size, (list, tuple)) else [size, size]

    # --- Get arrays from beam ---
    raw1 = getattr(beam, key1)
    raw2 = getattr(beam, key2)
    x, f1, p1 = nice_array(scale[0] * (raw1 - subtract_mean[0] * np.mean(raw1)))
    y, f2, p2 = nice_array(scale[1] * (raw2 - subtract_mean[1] * np.mean(raw2)))
    labelx = f"{key1} ({p1 + raw1.units})"
    labely = f"{key2} ({p2 + raw2.units})"

    # --- Compute PDF ---
    myPDF, v1, v2 = _screen_image_pdf(x, y, use_scipy, kwargs)
    myPDF = myPDF / myPDF.max() * iscale

    # --- Figure / Axes creation ---
    if ax is None:
        if marginals:
            fig = plt.figure(figsize=(12.41, 12.41))
            gs = fig.add_gridspec(
                2, 2,
                width_ratios=(8, 2),
                height_ratios=(2, 8),
                left=0.1, right=0.9,
                bottom=0.1, top=0.95,
                wspace=0.05, hspace=0.05,
            )
            ax = fig.add_subplot(gs[1, 0])
            ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
            ax_histy = fig.add_subplot(gs[1, 1], sharey=ax)
        else:
            fig = plt.figure(figsize=(10, 10))
            fig.subplots_adjust(top=0.95)
            ax = fig.add_subplot()
    else:
        fig = ax.figure

    # --- Determine size and view centre ---
    if size[0] is None:
        use_size = False
        if screen:
            xmin, xmax, ymin, ymax = -15, 15, -15, 15
            size = [15, 15]
        else:
            xmin, xmax = v1.min(), v1.max()
            ymin, ymax = v2.min(), v2.max()
            size = [xmax - xmin, ymax - ymin]
        meanvalx = 0 if subtract_mean[0] else (xmin + xmax) / 2
        meanvaly = 0 if subtract_mean[1] else (ymin + ymax) / 2
    else:
        use_size = True
        meanvalx = 0 if subtract_mean[0] else (v1.max() + v1.min()) / 2
        meanvaly = 0 if subtract_mean[1] else (v2.max() + v2.min()) / 2
        size = [size[0] / f1, size[1] / f2]

    # --- Set axis limits ---
    if limits is not None:
        limits = np.asarray(limits)
        if limits.shape == (2, 2):
            ax.set_xlim(limits[0])
            ax.set_ylim(limits[1])
        elif limits.shape == (2,):
            ax.set_xlim(limits)
            ax.set_ylim(limits)
        elif limits.shape == (4,):
            # [xmin, xmax, ymin, ymax], the form marginal_plot passes to Axes.axis
            ax.set_xlim(limits[:2])
            ax.set_ylim(limits[2:])
    elif screen or use_size:
        ax.set_xlim([meanvalx - (size[0] + 0.5), meanvalx + (size[0] + 0.5)])
        ax.set_ylim([meanvaly - (size[1] + 0.5), meanvaly + (size[1] + 0.5)])
    else:
        ax.set_xlim([v1.min(), v1.max()])
        ax.set_ylim([v2.min(), v2.max()])

    # --- Optional marginals (pdf is [iy, ix], so axis=0 is the x projection) ---
    if marginals:
        for draw, summed, edges in (
            (ax_histx.bar, myPDF.sum(axis=0), v1),
            (ax_histy.barh, myPDF.sum(axis=1), v2),
        ):
            widths = np.diff(edges)
            centres = edges[:-1] + widths / 2
            heights = nice_array(summed[:-1] / widths)[0]
            draw(centres, heights, widths, color=colormap(heights / np.max(heights)))

    # --- Screen circle and face color ---
    if screen:
        circ = plt.Circle((meanvalx, meanvaly), 15, facecolor="none")
        ax.add_artist(
            plt.Circle((meanvalx, meanvaly), 15, fill=True, ec="w", fc=colormap(0), zorder=-1)
        )
        ax.set_facecolor("k")
    else:
        circ = plt.Circle((meanvalx, meanvaly), 3 * max(size), facecolor="none")
        ax.set_facecolor(colormap(0))

    # --- Grid ---
    if grid:
        ax.grid(which="minor", color="w", alpha=0.3, clip_path=circ)
        ax.grid(which="major", color="w", alpha=0.55, clip_path=circ)

    # --- Main PDF ---
    ax.pcolormesh(v1, v2, myPDF, cmap=colormap, zorder=1, shading="auto")

    # --- Axis labels with optional size ---
    label_kw = {} if labelsize is None else {"fontsize": labelsize}
    ax.set_xlabel(labelx, **label_kw)
    ax.set_ylabel(labely, **label_kw)
    if labelsize is not None:
        ax.tick_params(axis="both", which="major", labelsize=labelsize)

    # --- Save file ---
    # Save the figure we drew into, which may not be pyplot's current figure
    # when an external ax was supplied.
    if isinstance(filename, str):
        fig.savefig(filename)
    plt.draw()
    return fig, ax
[docs]
def getScreenImage(
    beam,
    keys=("x", "y"),
    scale=(1, 1),
    iscale=1,
    colormap=plt.cm.jet,
    size=None,
    use_scipy=False,
    subtract_mean=(False, False),
    **kwargs,
):
    """
    Compute a normalised 2D density image of two beam coordinates.

    Parameters
    ----------
    beam
        Mapping-like object: ``beam[key]`` returns an array with ``.units``.
    keys : (str, str)
        Horizontal and vertical coordinates.
    scale : number or (number, number)
        Multiplier applied before choosing the nice-unit prefix.
    iscale : float
        Peak value of the returned density (it is divided by its maximum).
    colormap
        Returned unchanged, for the caller's convenience.
    size
        Accepted for signature compatibility with plotScreenImage; not used.
    use_scipy : bool
        Use the SciPy KDE even if fastKDE is installed.
    subtract_mean : bool or (bool, bool)
        Subtract each coordinate's mean.
    **kwargs
        Passed to ``fastKDE.pdf`` (ignored by the SciPy backend).

    Returns
    -------
    (v1, v2, pdf, colormap, labelx, labely)
        ``v1``/``v2`` are the axes returned by fastKDE, or 100 x 100
        ``np.mgrid`` arrays (indexed ``[ix, iy]``, like ``pdf``) for SciPy.

    Raises
    ------
    Exception
        If neither fastKDE nor SciPy is available.
    """
    key1, key2 = keys
    if not isinstance(subtract_mean, (list, tuple)):
        subtract_mean = [subtract_mean, subtract_mean]
    if not isinstance(scale, (list, tuple)):
        scale = [scale, scale]

    x, f1, p1 = nice_array(
        scale[0] * (beam[key1] - subtract_mean[0] * np.mean(beam[key1]))
    )
    y, f2, p2 = nice_array(
        scale[1] * (beam[key2] - subtract_mean[1] * np.mean(beam[key2]))
    )
    u1, u2 = [beam[k].units for k in keys]
    labelx = f"{key1} ({p1 + u1})"
    labely = f"{key2} ({p2 + u2})"

    # Do the self-consistent density estimate
    if fastKDE_installed and not use_scipy:
        myPDF, (v1, v2) = fastKDE.pdf(x, y, use_xarray=False, **kwargs)
    elif SciPy_installed:
        v1, v2 = np.mgrid[x.min():x.max():100j, y.min():y.max():100j]
        kernel = stats.gaussian_kde(np.vstack([x, y]))
        myPDF = kernel(np.vstack([v1.ravel(), v2.ravel()])).reshape(v1.shape)
    else:
        raise Exception("fastKDE or SciPy required")

    # Normalise the PDF so its peak equals iscale
    myPDF = myPDF / myPDF.max() * iscale
    return v1, v2, myPDF, colormap, labelx, labely