import math
import matplotlib.pyplot as plt
from copy import copy
import numpy as np
from ..units import nice_array, nice_scale_prefix
from mpl_axes_aligner import align
from ..Twiss import twissParameter, twiss_defaults
from laura.translator.converters.converter import translate_elements
# from units import nice_array, nice_scale_prefix
# Shallow copies of the named Matplotlib colormaps; the customisation below is
# applied to the copy held in this module.
CMAP0 = copy(plt.get_cmap("viridis"))
# Colour used by CMAP0 for values below the lower end of the colour range.
CMAP0.set_under("white")
CMAP1 = copy(plt.get_cmap("plasma"))
[docs]
def trans(M):
    """Return the transpose of the nested sequence ``M`` as a list of lists.

    The number of output rows is taken from ``len(M[0])``, so ``M`` must be
    non-empty and every row must be at least as long as the first one.
    """
    n_rows = len(M)
    transposed = []
    for col in range(len(M[0])):
        column = []
        for row in range(n_rows):
            column.append(M[row][col])
        transposed.append(column)
    return transposed
[docs]
def find_nearest(array, value):
    """Return the index of the entry of ``array`` that is closest to ``value``.

    ``array`` is searched with ``np.searchsorted`` and is therefore expected to
    be sorted in ascending order. When ``value`` is equally close to both
    neighbours, the index of the upper neighbour is returned.
    """
    upper = np.searchsorted(array, value, side="left")
    # Nothing lies below the insertion point: the first entry is the answer.
    if upper == 0:
        return upper
    # The value is beyond the last entry: the last entry is the answer.
    if upper == len(array):
        return upper - 1
    gap_below = math.fabs(value - array[upper - 1])
    gap_above = math.fabs(value - array[upper])
    if gap_below < gap_above:
        return upper - 1
    return upper
[docs]
def ASTRA_TW_FieldMap(fielddat, start, stop, cells, p):
    """Expand a travelling-wave field map by repeating its central section.

    The rows of ``fielddat`` are split at the rows whose first column equals
    ``start`` and ``stop``: everything before ``start`` is kept as the leading
    section, rows ``[start, stop)`` form the repeated section, and rows from
    ``stop`` onwards form the trailing section. The repeated section is emitted
    ``int(cells / p) + 1`` times, each copy shifted in the first column by the
    extent of the section, and the trailing section is shifted by
    ``int(cells / p)`` extents.

    Parameters
    ----------
    fielddat : array-like, shape (N, M)
        Field map rows; column 0 is the position used for splitting/shifting.
        The input is not modified.
    start, stop :
        Position values that must appear exactly in column 0.
    cells, p :
        The repeat count is derived from ``int(cells / p)``.

    Returns
    -------
    numpy.ndarray
        The expanded field map as a float array.

    Raises
    ------
    ValueError
        If ``start`` or ``stop`` is not present in column 0.
    """
    # Work on a private float copy: the sections below are views that get
    # shifted in place, which previously modified the caller's array.
    field = np.array(fielddat, dtype=float)
    zpos = list(field[:, 0])
    startpos = zpos.index(start)
    stoppos = zpos.index(stop)
    leading = field[:startpos]
    rfcell = field[startpos:stoppos]
    trailing = field[stoppos:]
    n_cells = int(cells / p)
    # NOTE(review): the extent is measured between the first and last row of
    # the repeated section, which excludes the ``stop`` row itself, so
    # consecutive copies share one position value — confirm this is intended.
    cell_length = rfcell[-1, 0] - rfcell[0, 0]
    pieces = [leading]
    for _ in range(n_cells + 1):
        pieces.append(rfcell.copy())
        rfcell[:, 0] += cell_length
    trailing[:, 0] += n_cells * cell_length
    pieces.append(trailing)
    return np.concatenate(pieces)
[docs]
def fieldmap_data(element, master_lattice):
    """
    Load the field map of ``element`` in absolute coordinates.

    The element is translated with ``translate_elements`` and its field
    definition is read with ``get_field_data(code="astra")``. Travelling-wave
    ``1DElectroDynamic`` cavities are expanded with ``ASTRA_TW_FieldMap``.

    Parameters
    ----------
    element :
        Lattice element providing ``physical`` positions, a field amplitude
        and ``hardware_type``.
    master_lattice :
        Passed through to ``translate_elements``.

    Returns
    -------
    numpy.ndarray
        Column 0 is the position shifted by the element's reference position;
        column 1 is the field rescaled so that its largest-magnitude sample
        equals the element's field amplitude (divided by 1e6 for RF cavities).
    """
    # Position: which end of the element the field map is referenced to.
    try:
        if element.field_reference_position == "start":
            offset = element.physical.start.z
        elif element.field_reference_position == "end":
            offset = element.physical.end.z
        else:
            offset = element.physical.middle.z
    except AttributeError:
        # Fall back to the element start. The original fallback misspelled
        # ``physical`` and therefore always raised AttributeError again.
        offset = element.physical.start.z
    # Scaling
    try:
        scale = element.field_amplitude
    except AttributeError:
        scale = element.simulation.field_amplitude
    if element.hardware_type.lower() == "rfcavity":
        # Presumably converts V/m to MV/m to match the plot axis label —
        # TODO confirm the unit of field_amplitude.
        scale = scale / 1e6
    # Field definition of the translated element
    element = translate_elements(elements=[element], master_lattice=master_lattice)[element.name]
    element.update_field_definition()
    field = element.simulation.field_definition
    data = field.get_field_data(code="astra")
    if field.field_type == "1DElectroDynamic" and field.cavity_type == "TravellingWave":
        dat = ASTRA_TW_FieldMap(
            np.transpose([field.z.value.val, field.Ez.value.val]),
            field.start_cell_z,
            field.end_cell_z,
            element.n_cells,
            field.mode_denominator,
        )
    else:
        dat = data
    dat[:, 0] += offset
    x = dat[:, 1]
    # Signed extreme value: the sample with the largest magnitude.
    normalise = max(x.min(), x.max(), key=abs)
    dat[:, 1] *= scale / normalise
    return dat
[docs]
class magnet_plotting_data:
    """Build the polygons and colours used to sketch beamline elements.

    Every element method takes an element ``e`` exposing
    ``e.physical.start.z`` and ``e.physical.end.z`` and returns a tuple
    ``(corners, colour)`` where ``corners`` is a 4x2 array of (z, height)
    points.
    """

    def __init__(
        self,
        kinetic_energy=None,
    ):
        """Store ``z`` and ``kinetic_energy`` columns.

        ``kinetic_energy`` is a sequence of (z, kinetic energy) rows; without
        it a single placeholder sample is stored.
        """
        if kinetic_energy is None:
            self.z, self.kinetic_energy = ([0.0], [1.0])
        else:
            self.z, self.kinetic_energy = trans(kinetic_energy)

    def half_rectangle(self, e, half_height):
        """Rectangle spanning the element from the axis to ``half_height``."""
        z_in = e.physical.start.z
        z_out = e.physical.end.z
        corners = [
            [z_in, 0],
            [z_in, half_height],
            [z_out, half_height],
            [z_out, 0],
        ]
        return np.array(corners)

    def full_rectangle(self, e, half_height, width=0):
        """Rectangle symmetric about the axis, widened by ``width`` per side."""
        z_left = e.physical.start.z - width
        z_right = e.physical.end.z + width
        corners = [
            [z_left, -half_height],
            [z_left, half_height],
            [z_right, half_height],
            [z_right, -half_height],
        ]
        return np.array(corners)

    def quadrupole(self, e):
        """Red half rectangle; only the sign of ``e.k1l`` sets the direction."""
        height = np.sign(e.k1l) * 0.5
        outline = self.half_rectangle(e, height)
        return outline, "red"

    def sextupole(self, e):
        """Green half rectangle; only the sign of ``e.k2l`` sets the direction."""
        height = np.sign(e.k2l) * 0.5
        outline = self.half_rectangle(e, height)
        return outline, "green"

    def dipole(self, e):
        """Blue half rectangle; only the sign of ``e.angle`` sets the direction."""
        height = np.sign(e.angle) * 0.4
        outline = self.half_rectangle(e, height)
        return outline, "blue"

    def beam_position_monitor(self, e):
        """Purple full rectangle of fixed half height 0.1."""
        outline = self.full_rectangle(e, 0.1)
        return outline, "purple"

    def screen(self, e):
        """Green full rectangle of fixed half height 0.33."""
        outline = self.full_rectangle(e, 0.33)
        return outline, "green"

    def aperture(self, e):
        """Black full rectangle of half height 0.15, widened by 0.01 per side."""
        outline = self.full_rectangle(e, 0.15, width=0.01)
        return outline, "black"

    def wall_current_monitor(self, e):
        """Brown full rectangle of fixed half height 0.33."""
        outline = self.full_rectangle(e, 0.33)
        return outline, "brown"
[docs]
def load_elements(
    lattice,
    bounds=None,
    sections="All",
    types=["RFCavity", "Solenoid"],
    kinetic_energy=None,
    verbose=False,
    scale=1,
):
    """Collect plotting data for every lattice element of the given types.

    Parameters
    ----------
    lattice :
        Lattice object; elements are looked up by name and by
        ``getElementType``.
    bounds : sequence of two numbers, optional
        Keep only elements starting at or before ``bounds[1]`` and ending at
        or after ``bounds[0] - 0.1``.
    sections : "All" or iterable of section names
        Restrict the search to the named lattice sections.
    types : list of str
        Element types to collect.
    kinetic_energy :
        Forwarded to ``magnet_plotting_data``.
    verbose, scale :
        Accepted but not used here.

    Returns
    -------
    dict
        ``{type: {element name: data}}``. ``data`` is a field-map array for
        RF cavities / solenoids that carry a field definition, otherwise
        whatever the ``magnet_plotting_data`` method named after the type
        returns. Types with neither produce a printed notice per element.
    """
    master_lattice = lattice.global_parameters["master_lattice"]
    fmap = {}
    drawer = magnet_plotting_data(kinetic_energy=kinetic_energy)
    for t in types:
        per_type = fmap[t] = {}
        # Gather the candidate elements of this type.
        if sections == "All":
            candidates = [lattice[item["name"]] for item in lattice.getElementType(t)]
        else:
            candidates = [
                lattice[item["name"]]
                for s in sections
                for item in lattice[s].getElementType(t)
            ]
        # Discard elements that lie outside the requested window.
        if bounds is not None:
            inside = []
            for elem in candidates:
                if (
                    elem.physical.start.z <= bounds[1]
                    and elem.physical.end.z >= bounds[0] - 0.1
                ):
                    inside.append(elem)
            candidates = inside
        uses_fieldmap = t in ("RFCavity", "Solenoid")
        for elem in candidates:
            if uses_fieldmap and getattr(elem, "field_definition", None) is not None:
                per_type[elem.name] = fieldmap_data(elem, master_lattice)
                continue
            if hasattr(drawer, t):
                per_type[elem.name] = getattr(drawer, t)(elem)
                continue
            print("Missing drawings for", t)
    return fmap
[docs]
def add_fieldmaps_to_axes(
    lattice,
    axes,
    bounds=None,
    sections="All",
    fields=["RFCavity", "Solenoid"],
    include_labels=True,
    verbose=False,
):
    """
    Adds fieldmaps to an axes.

    The first entry of ``fields`` is drawn on ``axes`` and the second on a
    twin y-axis created here, so at most two field types are supported and
    each must be one of ``"RFCavity"`` or ``"Solenoid"``. A black baseline
    from x=0 to x=100 is always drawn on ``axes``.

    Parameters
    ----------
    lattice, bounds, sections, verbose :
        Forwarded to ``load_elements``.
    axes :
        Matplotlib axes to draw on.
    fields : list of str
        Field types to draw, in axis order.
    include_labels :
        Accepted but not used here.
    """
    fmaps = load_elements(
        lattice, bounds=bounds, sections=sections, verbose=verbose, types=fields
    )
    ax1 = axes
    ax1rhs = ax1.twinx()
    ax = [ax1, ax1rhs]
    ylabel = {"RFCavity": "$E_z$ (MV/m)", "Solenoid": "$B_z$ (T)"}
    color = {"RFCavity": "green", "Solenoid": "blue"}
    for i, section in enumerate(fields):
        a = ax[i]
        # Resolve the colour before the element loop: it is also needed for
        # the axis label when the lattice holds no element of this type.
        c = color[section]
        for name, data in fmaps[section].items():
            a.plot(*data.T, label=f"{section}_{name}", color=c)
        a.yaxis.label.set_color(c)
        a.set_ylabel(ylabel[section])
    if len(fields) < 1:
        # Nothing to scale against: hide the tick marks of both y-axes.
        for a in ax:
            a.set_yticks([])
    # Zero baseline
    data = np.array([[0, 0], [100, 0]])
    ax[0].plot(*data.T, color="black")
[docs]
def add_magnets_to_axes(
    lattice,
    axes,
    bounds=None,
    sections="All",
    magnets=["quadrupole", "dipole", "sextupole", "beam_position_monitor", "screen"],
    include_labels=True,
    kinetic_energy=None,
    verbose=False,
):
    """
    Adds magnets to an axes.

    Every element type requested in ``magnets`` that ``load_elements`` returns
    as a ``(polygon, colour)`` pair is filled on a twin y-axis created here.
    ``axes`` carries the dipole axis label and a black baseline from x=0 to
    x=100; the zeros of both y-axes are aligned at mid height.

    Parameters
    ----------
    lattice, bounds, sections, kinetic_energy, verbose :
        Forwarded to ``load_elements``.
    axes :
        Matplotlib axes to draw on.
    magnets : list of str
        Element types to draw.
    include_labels :
        Accepted but not used here.
    """
    max_scale = 0
    fmaps = load_elements(
        lattice,
        bounds=bounds,
        sections=sections,
        verbose=verbose,
        types=magnets,
        scale=max_scale,
        kinetic_energy=kinetic_energy,
    )
    ax1 = axes
    ax1rhs = ax1.twinx()
    ax = [ax1, ax1rhs]
    ylabel = {
        "dipole": r"$\theta$ (rad)",
        "quadrupole": "$K_n$ (T/m)",
    }
    axis = {"dipole": 0, "quadrupole": 1}
    color = {
        "dipole": "blue",
        "quadrupole": "red",
    }
    for section, i in axis.items():
        a = ax[i]
        a.set_ylabel(ylabel[section])
        a.yaxis.label.set_color(color[section])
    # All polygons go on the right-hand axis. This was previously implicit:
    # the drawing loop reused the last axis left over from the label loop.
    fill_ax = ax[1]
    # Draw every loaded type. Iterating a fixed colour table here used to skip
    # requested types missing from it (screens, wall current monitors,
    # apertures) even though each entry already carries its own colour.
    for section_elements in fmaps.values():
        for entry in section_elements.values():
            if not isinstance(entry, tuple):
                # Field-map arrays are not (polygon, colour) pairs.
                continue
            polygon, c = entry
            fill_ax.fill(*polygon.T, color=c)
    # Zero baseline
    data = np.array([[0, 0], [100, 0]])
    ax[0].plot(*data.T, color="black")
    if bounds:
        ax1.set_xlim(bounds[0], bounds[1])
    align.yaxes(ax[0], 0, ax[1], 0, 0.5)
[docs]
def plot_fieldmaps(
    lattice,
    sections="All",
    include_labels=True,
    limits=None,
    figsize=(12, 4),
    fields=["RFCavity", "Solenoid"],
    magnets=["quadrupole", "dipole", "beam_position_monitor", "screen"],
    **kwargs,
):
    """
    Simple fieldmap plot

    Creates a new figure with ``plt.subplots(figsize=figsize, **kwargs)`` and
    draws the field maps of ``lattice`` on it via ``add_fieldmaps_to_axes``.
    ``limits`` is passed on as the ``bounds`` of that call; ``magnets`` is
    accepted but not used. Nothing is returned.
    """
    _, field_axes = plt.subplots(figsize=figsize, **kwargs)
    layout_options = dict(
        bounds=limits,
        include_labels=include_labels,
        sections=sections,
        fields=fields,
    )
    add_fieldmaps_to_axes(lattice, field_axes, **layout_options)
[docs]
def plot(
    framework_object,
    ykeys=["sigma_x", "sigma_y"],
    ykeys2=["sigma_z"],
    xkey="z",
    limits=None,
    nice=True,
    include_layout=False,
    include_labels=True,
    include_legend=True,
    include_particles=False,
    fields=["RFCavity", "Solenoid"],
    magnets=[
        "quadrupole",
        "dipole",
        "beam_position_monitor",
        "screen",
        "wall_current_monitor",
        "aperture",
    ],
    grid=False,
    ax_top=None,
    ax_field_layout=None,
    ax_magnet_layout=None,
    **kwargs,
):
    """Plot twiss statistics against ``xkey``, optionally with a layout sketch.

    Parameters
    ----------
    framework_object :
        Object providing ``twiss`` (with ``sort`` and ``stat``), ``beams`` and,
        when a layout is drawn, ``framework``.
    ykeys, ykeys2 : str or list of str
        Statistics for the left and right y-axis. All keys on one axis must
        share a unit.
    xkey : str
        Statistic used for the x-axis.
    limits : sequence of two numbers, optional
        Restrict the plotted x-range; one sample beyond each end is kept.
    nice : bool
        Rescale values with ``nice_array`` / ``nice_scale_prefix`` and prefix
        the units accordingly.
    include_layout :
        Draw field maps and magnets on two extra axes.
    include_labels :
        Forwarded to the layout helpers.
    include_legend : bool
        Draw a legend; switched off for a single curve.
    include_particles : bool
        Add one scatter point per entry of ``framework_object.beams``.
    fields, magnets : list of str
        Element types forwarded to the layout helpers.
    grid : bool
        Show major grid lines on the top axes.
    ax_top, ax_field_layout, ax_magnet_layout :
        Existing axes to draw on instead of creating a figure.
    **kwargs :
        Forwarded to ``plt.subplots`` when a figure is created.

    Returns
    -------
    tuple
        ``(ax_top, ax_field_layout, ax_magnet_layout)``; the layout axes are
        whatever was passed in (``None`` by default) unless created here.

    Raises
    ------
    ValueError
        If the keys of one y-axis have different units.
    """
    twiss = framework_object.twiss
    twiss.sort()
    P = framework_object.beams
    # ------------------------------------------------------------
    # AXIS CREATION OR AXIS INJECTION
    # ------------------------------------------------------------
    # Caller-supplied axes are used only when every axis that will be drawn
    # on has been supplied.
    external_axes = (
        ax_top is not None
        and (include_layout is False or (ax_field_layout is not None and ax_magnet_layout is not None))
    )
    if include_layout is not False:
        if not external_axes:
            if "sharex" not in kwargs:
                kwargs["sharex"] = True
            fig, all_axis = plt.subplots(
                3,
                gridspec_kw={"height_ratios": [4, 1, 1]},
                subplot_kw=dict(frameon=False),
                **kwargs,
            )
            ax_top = all_axis[0]
            ax_field_layout = all_axis[1]
            ax_magnet_layout = all_axis[2]
            fig.subplots_adjust(hspace=0)
    else:
        if not external_axes:
            fig, ax_top = plt.subplots(**kwargs)
    if grid:
        ax_top.grid(visible=True, which="major", color="#666666", linestyle="-")
    # Collect data keys
    if isinstance(ykeys, str):
        ykeys = [ykeys]
    if ykeys2:
        if isinstance(ykeys2, str):
            ykeys2 = [ykeys2]
        ax_right = ax_top.twinx()
        ax_plot = [ax_top, ax_right]
    else:
        ax_plot = [ax_top]
    # No need for a legend when only one curve is drawn
    if len(ykeys) == 1 and not ykeys2:
        include_legend = False
    # X-axis data
    X = twiss.stat(xkey).val
    if xkey in twiss_defaults:
        X = twissParameter(val=X, **twiss_defaults[xkey])
    else:
        X = twissParameter(val=X, name=xkey, unit="")
    # Apply limits
    if limits:
        good = np.logical_and(X.val >= limits[0], X.val <= limits[1])
        idx = np.where(good)[0]
        if len(idx) > 0:
            # Also keep the sample just outside each end of the window
            if idx[0] > 0:
                good[idx[0] - 1] = True
            if idx[-1] + 1 < len(good):
                good[idx[-1] + 1] = True
        X = X[good]
        # NOTE(review): min()/max() raise ValueError if no sample is selected
        # Shrink the limits to the available data
        if min(X.val) > limits[0]:
            limits = (min(X.val), limits[1])
        if max(X.val) < limits[1]:
            limits = (limits[0], max(X.val))
    else:
        limits = (min(X.val), max(X.val))
        good = slice(None)
    # Particles
    Pnames = []
    X_particles = []
    if include_particles:
        for pname in range(len(P)):
            xp = np.mean(getattr(P[pname], xkey))
            if limits[0] <= xp <= limits[1]:
                Pnames.append(pname)
                X_particles.append(xp)
        X_particles = np.array(X_particles)
    # Units + nice scaling
    units_x = X.unit
    if nice:
        X.val, factor_x, prefix_x = nice_array(X.val)
        units_x = prefix_x + units_x
    else:
        factor_x = 1
    # Set x axis label
    ax_plot[0].set_xlabel(f"{xkey} ({units_x})")
    # Plot ykeys + ykeys2
    linestyles = ["solid", "dashed"]
    legend_labels = []
    # Running curve counter, used to pick the colour cycle entry "C<n>"
    line_index = -1
    for idx_axis, keys in enumerate([ykeys, ykeys2]):
        if not keys:
            continue
        ax = ax_plot[idx_axis]
        linestyle = linestyles[idx_axis]
        # Determine units
        units_list = []
        for key in keys:
            Y = twiss.stat(key).val
            if key in twiss_defaults:
                Y = twissParameter(val=Y, **twiss_defaults[key])
            else:
                Y = twissParameter(val=Y, name=key, unit="")
            units_list.append(Y.unit)
        # Unit compatibility
        if len(set(units_list)) > 1:
            raise ValueError("Incompatible units among ykeys.")
        unit = units_list[0]
        data_list = [twiss.stat(key).val[good] for key in keys]
        labels = [twiss.stat(key).label for key in keys]
        if nice:
            # One common scale factor for all curves on this axis
            factor, prefix = nice_scale_prefix(np.ptp(data_list))
            unit = prefix + unit
        else:
            factor = 1
        for key, dat, label in zip(keys, data_list, labels):
            line_index += 1
            color = f"C{line_index}"
            # Typeset Greek symbol names as math text
            for symbol in ["beta", "alpha", "gamma", "sigma"]:
                if symbol in label:
                    label = "$" + label.replace(symbol, '\\' + symbol) + "$"
            legend_labels.append(label)
            ax.plot(
                X.val,
                dat / factor,
                label=f"{label} ({unit})",
                color=color,
                linestyle=linestyle,
            )
            # Particle plots
            if len(Pnames) > 0:
                Yp = np.array([
                    np.std(getattr(P[name], key))
                    if key in P._parameters["data"]
                    else getattr(P[name], key)
                    for name in Pnames
                ])
                ax.scatter(X_particles / factor_x, Yp / factor, color=color)
        # Axis label: all curve labels of this axis, with the shared unit
        ylabel = []
        for l in labels:
            for symbol in ["beta", "alpha", "gamma", "sigma"]:
                if symbol in l:
                    l = "$" + l.replace(symbol, '\\' + symbol) + "$"
            ylabel.append(f"{l}")
        ylabel = ", ".join(ylabel)
        ax.set_ylabel(ylabel + f" ({unit})")
    # Legend
    if include_legend:
        handles = []
        for ax in ax_plot:
            h, _ = ax.get_legend_handles_labels()
            handles.extend(h)
        ax_plot[0].legend(handles, legend_labels, loc="best")
    # Accelerator layout
    # NOTE(review): ``limits`` is in unscaled x units while ax_top holds X after
    # nice_array scaling — confirm the x-limits copied below match when the
    # scale factor is not 1.
    if include_layout:
        add_fieldmaps_to_axes(
            framework_object.framework,
            ax_field_layout,
            bounds=limits,
            include_labels=include_labels,
            fields=fields,
        )
        add_magnets_to_axes(
            framework_object.framework,
            ax_magnet_layout,
            bounds=limits,
            include_labels=include_labels,
            magnets=magnets,
            kinetic_energy=list(
                zip(twiss.stat("z").val[good], twiss.stat("kinetic_energy").val[good])
            ),
        )
        ax_field_layout.set_xlim(ax_top.get_xlim())
        ax_magnet_layout.set_xlim(ax_top.get_xlim())
    if external_axes and include_layout:
        ax_field_layout.sharex(ax_top)
        ax_magnet_layout.sharex(ax_top)
    # 2. X label placement
    if include_layout:
        if external_axes:
            # external axes: put label only on bottom axis
            ax_top.set_xlabel("")
            ax_field_layout.set_xlabel("")
            ax_magnet_layout.set_xlabel(f"{xkey} ({units_x})")
            # show ticks only on bottom axis
            ax_top.tick_params(labelbottom=False)
            ax_field_layout.tick_params(labelbottom=False)
            ax_magnet_layout.tick_params(labelbottom=True)
        else:
            # internal axes: bottom axis gets the x-label
            ax_top.set_xlabel("")
            ax_field_layout.set_xlabel("")
            ax_magnet_layout.set_xlabel(f"{xkey} ({units_x})")
            ax_top.tick_params(labelbottom=False)
            ax_field_layout.tick_params(labelbottom=False)
            ax_magnet_layout.tick_params(labelbottom=True)
    else:
        # No layout → single axis case
        ax_top.set_xlabel(f"{xkey} ({units_x})")
    return ax_top, ax_field_layout, ax_magnet_layout
[docs]
def getattrsplit(self, attr):
    """Resolve the dotted attribute path ``attr`` (e.g. ``"a.b.c"``) on ``self``.

    Each dot-separated name is looked up with ``getattr`` on the result of the
    previous lookup; the final object is returned.
    """
    target = self
    for part in attr.split("."):
        target = getattr(target, part)
    return target
[docs]
def general_plot(
    framework_object,
    ykeys=[],
    ykeys2=[],
    xkey="z",
    limits=None,
    nice=True,
    include_layout=False,
    include_labels=True,
    include_legend=True,
    include_particles=False,
    fields=["RFCavity", "Solenoid"],
    magnets=[
        "quadrupole",
        "dipole",
        "beam_position_monitor",
        "screen",
        "wall_current_monitor",
        "aperture",
    ],
    grid=False,
    **kwargs,
):
    """Plot attributes of ``framework_object.twiss`` addressed by dotted paths.

    Parameters
    ----------
    framework_object :
        Object providing ``twiss`` and, when a layout is drawn, ``framework``.
    ykeys, ykeys2 : str or list of str
        Dotted attribute paths below ``framework_object.twiss`` for the left
        and right y-axis. The resolved objects must expose ``val`` and
        ``unit``; all keys on one axis must share a unit.
    xkey : str
        Dotted attribute path of the x data.
    limits : sequence of two numbers, optional
        Restrict the plotted x-range; one sample beyond each end is kept.
        The passed sequence is not modified.
    nice : bool
        Rescale values with ``nice_array`` / ``nice_scale_prefix`` and prefix
        the units accordingly.
    include_layout :
        Add a second axes with the field-map layout.
    include_labels :
        Forwarded to ``add_fieldmaps_to_axes``.
    include_legend : bool
        Draw a legend; switched off for a single curve.
    include_particles, magnets :
        Accepted but not used here.
    fields : list of str
        Field types forwarded to ``add_fieldmaps_to_axes``.
    grid : bool
        Show major grid lines on the main axes.
    **kwargs :
        Forwarded to ``plt.subplots``.

    Returns
    -------
    tuple
        ``(plt, fig, all_axis)``.

    Raises
    ------
    AssertionError
        If the keys of one y-axis have different units.
    """
    if include_layout is not False:
        fig, all_axis = plt.subplots(2, gridspec_kw={"height_ratios": [4, 1]}, **kwargs)
        ax_layout = all_axis[-1]
        ax_plot = [all_axis[0]]
    else:
        fig, all_axis = plt.subplots(**kwargs)
        ax_plot = [all_axis]
    if grid:
        # ``visible`` replaces the ``b`` keyword, which current Matplotlib no
        # longer accepts; this matches the call made in ``plot``.
        ax_plot[0].grid(visible=True, which="major", color="#666666", linestyle="-")
    # collect axes
    if isinstance(ykeys, str):
        ykeys = [ykeys]
    if ykeys2:
        if isinstance(ykeys2, str):
            ykeys2 = [ykeys2]
        ax_plot.append(ax_plot[0].twinx())
    # Resolve the dotted paths on the twiss object
    xdata = getattrsplit(framework_object.twiss, xkey)
    ydata = [getattrsplit(framework_object.twiss, y) for y in ykeys]
    ydata2 = [getattrsplit(framework_object.twiss, y) for y in ykeys2]
    # Split keys: only the last path component is used for labels
    xkey = xkey.split(".")[-1]
    ykeys = [yk.split(".")[-1] for yk in ykeys]
    ykeys2 = [yk.split(".")[-1] for yk in ykeys2]
    # No need for a legend if there is only one plot
    if len(ydata) == 1 and not ydata2:
        include_legend = False
    X = xdata
    # Only get the data we need
    if limits:
        good = np.logical_and(X >= limits[0], X <= limits[1])
        # ``good is True`` compared identities and never matched, so the
        # neighbouring samples were never added; test the mask itself.
        idx = np.where(good)[0]
        if len(idx) > 0:
            # Also keep the sample just outside each end of the window
            if idx[0] > 0:
                good[idx[0] - 1] = True
            if (idx[-1] + 1) < len(good):
                good[idx[-1] + 1] = True
        X = X[good]
        # Shrink the limits to the available data. A new tuple is built so the
        # caller's sequence is left untouched and tuples are accepted.
        lower, upper = limits[0], limits[1]
        if X.min() > lower:
            lower = X.min()
        if X.max() < upper:
            upper = X.max()
        limits = (lower, upper)
    else:
        limits = X.min(), X.max()
        good = slice(None, None, None)  # everything
    # X axis scaling
    units_x = xdata.unit
    if nice:
        X, factor_x, prefix_x = nice_array(X.val)
        units_x = prefix_x + units_x
    else:
        factor_x = 1
    # set all but the layout
    for ax in ax_plot:
        ax.set_xlim(limits[0] / factor_x, limits[1] / factor_x)
        ax.set_xlabel(f"{xkey} ({units_x})")
    # Draw for Y1 and Y2
    linestyles = ["solid", "dashed"]
    symbols = ["beta", "alpha", "gamma", "sigma"]
    ii = -1  # counter for colors
    for ix, (d, keys) in enumerate([[ydata, ykeys], [ydata2, ykeys2]]):
        if not keys:
            continue
        ax = ax_plot[ix]
        linestyle = linestyles[ix]
        # Check that units are compatible. AssertionError is kept for
        # backward compatibility but raised explicitly so the check also
        # runs under ``python -O``.
        ulist = [dat.unit for dat in d]
        for u2 in ulist[1:]:
            if not ulist[0] == u2:
                raise AssertionError(f"Incompatible units: {ulist[0]} and {u2}")
        # String representation
        unit = str(ulist[0])
        # Data
        data = [key.val[good] for key in d]
        if nice:
            # One common scale factor for all curves on this axis
            factor, prefix = nice_scale_prefix(np.ptp(data))
            unit = prefix + unit
        else:
            factor = 1
        # One label per key, in key order, so every curve is drawn and keeps
        # its own label; Greek symbol names are typeset as math text.
        labels = []
        for label in keys:
            for symbol in symbols:
                if symbol in label:
                    label = "$" + label.replace(symbol, "\\" + symbol) + "$"
            labels.append(label)
        for label, dat in zip(labels, data):
            ii += 1
            color = "C" + str(ii)
            ax.plot(
                X,
                dat / factor,
                label=f"{label} ({unit})",
                color=color,
                linestyle=linestyle,
            )
        ax.set_ylabel(", ".join(labels) + f" ({unit})")
    # Collect legend
    if include_legend:
        lines = []
        labels = []
        for ax in ax_plot:
            a, b = ax.get_legend_handles_labels()
            lines += a
            labels += b
        ax_plot[0].legend(lines, labels, loc="best")
    # Layout
    if include_layout is not False:
        if xkey == "z":
            ax_layout.set_xlim(limits[0], limits[1])
        add_fieldmaps_to_axes(
            framework_object.framework,
            ax_layout,
            bounds=limits,
            include_labels=include_labels,
            fields=fields,
        )
    return plt, fig, all_axis