"""
SIMBA Ocelot Module
Various objects and functions to handle OCELOT lattices and commands. See `Ocelot github`_ for more details.
.. _Ocelot github: https://github.com/ocelot-collab/ocelot
Classes:
- :class:`~simba.Codes.Ocelot.Ocelot.ocelotLattice`: The Ocelot lattice object, used for
converting the :class:`~simba.Framework_objects.frameworkObject` objects defined in the
:class:`~simba.Framework_objects.frameworkLattice` into an Ocelot lattice object,
and for tracking through it.
"""
from ...Framework_objects import frameworkLattice, getGrids
from ...Modules import Beams as rbf
from ...Modules.Fields import field
from ...Modules.Twiss.ocelot import save_ocelot_twiss_hdf
from copy import deepcopy
from numpy import array, savez_compressed, linspace, save, interp
import os
from yaml import safe_load
# Default OCELOT settings shipped with this module. ocelotLattice.model_post_init
# falls back to them whenever settings["global"] has no "OCELOTsettings" entry.
# The YAML file is read as UTF-8 explicitly so parsing does not depend on the
# platform's locale encoding.
with open(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "ocelot_defaults.yaml"),
    "r",
    encoding="utf-8",
) as infile:
    oceglobal = safe_load(infile)
import lox
from lox.worker.thread import ScatterGatherDescriptor
from typing import Dict, List, Any, ClassVar
from laura.models.diagnostic import DiagnosticElement
[docs]
class ocelotLattice(frameworkLattice):
"""
Class for defining the OCELOT lattice object, used for
converting the :class:`~simba.Framework_objects.frameworkObject` objects defined in the
:class:`~simba.Framework_objects.frameworkLattice` into an Ocelot lattice object,
and for tracking through it.
"""
screen_threaded_function: ClassVar[ScatterGatherDescriptor] = (
ScatterGatherDescriptor
)
"""Function for converting all screen outputs from ELEGANT into the SIMBA generic
:class:`~simba.Modules.Beams.beam` object and writing files"""
code: str = "ocelot"
"""String indicating the lattice object type"""
trackBeam: bool = True
"""Flag to indicate whether to track the beam"""
lat_obj: Any = None
"""Lattice object as an Ocelot `MagneticLattice`_
.. _MagneticLattice: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/magnetic_lattice.py
"""
pin: Any = None
"""Initial particle distribution as an Ocelot `ParticleArray`_
.. _ParticleArray: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/beam.py"""
pout: Any = None
"""Final particle distribution as an Ocelot `ParticleArray`_"""
tws: List = None
"""List containing Ocelot `Twiss`_ objects
.. _Twiss: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/beam.py
"""
names: List = None
"""Names of elements in the lattice"""
grids: getGrids = None
"""Class for calculating the required number of space charge grids"""
oceglobal: Dict = {}
"""Global settings for Ocelot, read in from `ocelotLattice.settings["global"]["OCELOTsettings"]` and
`ocelot_defaults.yaml`"""
unit_step: float = 0.01
"""Step for Ocelot `PhysProc`_ objects
.. _PhysProc: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/physics_proc.py
"""
smooth_param: float = 0.01
"""Smoothing parameter"""
lsc: bool = True
"""Flag to enable LSC calculations"""
random_mesh: bool = True
"""Random meshing for space charge calculations"""
nbin_csr: int = 10
"""Number of longitudinal bins for CSR calculations"""
mbin_csr: int = 5
"""Number of macroparticle bins for CSR calculations"""
wake_factor: float = 1.0
"""Multiplication factor for wakefields"""
sigmamin_csr: float = 1e-5
"""Minimum size for CSR calculations"""
wake_sampling: int = 1000
"""Number of samples for wake calculations"""
wake_filter: int = 10
"""Filter parameter for wake calculations"""
particle_definition: str = None
"""Initial particle distribution as a string"""
final_screen: Any = None
"""Final screen object"""
mbi_navi: Any | None = None
"""Physics process for calculating microbunching gain"""
mbi: Dict = {}
"""Dictionary containing settings for microbunching gain calculation"""
ref_s: float = None
"""Reference s position"""
ref_idx: int = None
"""Reference particle index"""
def model_post_init(self, __context):
    """
    Finish initialising the lattice after the model fields have been set.

    * ``self.oceglobal`` is taken from ``settings["global"]["OCELOTsettings"]``
      when that key exists, otherwise from the module-level defaults loaded
      from ``ocelot_defaults.yaml``.
    * Every model field whose name is a key of ``self.oceglobal`` is set from
      it; otherwise, if the name is a key of ``file_block``, it is set from
      there. ``self.oceglobal`` therefore wins when both define a field.
    * ``particle_definition`` comes from ``file_block["input"]`` when present,
      with "initial_distribution" mapped to "laser". Otherwise it defaults to
      ``self.start``.
    * A fresh ``getGrids`` helper is stored in ``self.grids``.
    """
    super().model_post_init(__context)

    global_settings = self.settings["global"]
    if "OCELOTsettings" in global_settings.keys():
        self.oceglobal = global_settings["OCELOTsettings"]
    else:
        self.oceglobal = oceglobal

    for field_name in type(self).model_fields:
        if field_name in self.oceglobal.keys():
            setattr(self, field_name, self.oceglobal[field_name])
        elif field_name in self.file_block:
            setattr(self, field_name, self.file_block[field_name])

    has_explicit_definition = (
        "input" in self.file_block
        and "particle_definition" in self.file_block["input"]
    )
    if has_explicit_definition:
        requested = self.file_block["input"]["particle_definition"]
        # "initial_distribution" is handled as the "laser" definition
        self.particle_definition = (
            "laser" if requested == "initial_distribution" else requested
        )
    else:
        self.particle_definition = self.start

    self.grids = getGrids()
[docs]
def writeElements(self) -> None:
    """
    Create Ocelot objects for all the elements in the lattice and set the
    :attr:`~simba.Codes.Ocelot.Ocelot.ocelotLattice.lat_obj` and
    :attr:`~simba.Codes.Ocelot.Ocelot.ocelotLattice.names`.

    ``names`` holds the string form of each element's ``id``, in the same
    order as ``self.lat_obj.sequence``. Other methods rely on this, using
    ``self.lat_obj.sequence[self.names.index(name)]`` to find an element.
    """
    self.lat_obj = self.section.to_ocelot(save=True)
    # Convert each id directly; a detour through a numpy array is an extra
    # copy and can coerce mixed-type ids to a common dtype first.
    self.names = [str(element.id) for element in self.lat_obj.sequence]
[docs]
def write(self) -> None:
    """
    Build the Ocelot lattice for this section.

    All the work is delegated to
    :func:`~simba.Codes.Ocelot.Ocelot.ocelotLattice.writeElements`. That
    method calls ``self.section.to_ocelot(save=True)`` and fills in
    ``lat_obj`` and ``names``. Any file output is presumably produced by
    that ``save=True`` call rather than by this method.
    """
    # Delegate: the lattice is built (and stored on self) by writeElements
    self.writeElements()
[docs]
def preProcess(self) -> None:
    """
    Load the initial particle distribution and convert it for Ocelot.

    Steps:

    1. Run the parent class pre-processing.
    2. Read the input beam. The prefix comes from ``get_prefix()``, with
       ``particle_definition`` appended when ``trackBeam`` is False.
    3. Record the loaded beam's ``s`` and reference particle index in
       :attr:`ref_s` and :attr:`ref_idx`.
    4. Build :attr:`pin` via :meth:`hdf5_to_npz`.
    """
    super().preProcess()

    prefix = self.get_prefix()
    if not self.trackBeam:
        prefix = prefix + self.particle_definition
    self.read_input_file(prefix, self.particle_definition)

    beam = self.global_parameters["beam"]
    self.ref_s = beam.s
    self.ref_idx = beam.reference_particle_index

    self.hdf5_to_npz(prefix)
[docs]
def hdf5_to_npz(self, prefix: str = "", write: bool = True) -> None:
    """
    Convert the loaded SIMBA beam into an Ocelot ``ParticleArray`` and store
    it in :attr:`~simba.Codes.Ocelot.Ocelot.ocelotLattice.pin`.

    The source beam is ``self.global_parameters["beam"]``, and :attr:`ref_s`
    is passed as ``s_start``. Despite the method name, no ``.npz`` (or any
    other) file is written here.

    Parameters
    ----------
    prefix: str
        Prefix for the particle file. Not used by this implementation; kept
        for interface compatibility.
    write: bool
        Not used by this implementation; kept for interface compatibility.
    """
    from ...Modules.Beams import ocelot as beams_ocelot

    source_beam = self.global_parameters["beam"]
    self.pin = beams_ocelot.particle_group_to_parray(source_beam, s_start=self.ref_s)
[docs]
def run(self) -> None:
    """
    Track :attr:`pin` through :attr:`lat_obj` with Ocelot.

    Stores the list of Twiss objects in :attr:`tws` and the final particle
    distribution in :attr:`pout`.

    The physics processes come from :meth:`navi_setup`. The input array is
    deep-copied so :attr:`pin` itself is left untouched. When
    ``sample_interval > 1``, the copy is thinned to every
    ``sample_interval``-th particle before tracking.
    """
    from ocelot.cpbd.track import track

    navigator = self.navi_setup()
    particles = deepcopy(self.pin)
    if self.sample_interval > 1:
        particles = particles.thin_out(nth=self.sample_interval)

    twiss_list, final_particles = track(
        self.lat_obj,
        particles,
        navi=navigator,
        calc_tws=True,
        twiss_disp_correction=True,
    )
    self.tws = twiss_list
    self.pout = final_particles
[docs]
def postProcess(self) -> None:
    """
    Write the Ocelot tracking outputs to `master_subdir`.

    The Twiss list :attr:`tws` is turned into a dict of per-attribute lists:

    * ``s`` is shifted by ``self.startObject.physical.start.z``;
    * a ``z`` column is interpolated from the element positions.

    The result is written with ``save_ocelot_twiss_hdf`` to
    ``<objectname>_twiss.oh5``. If a microbunching-gain process was set up
    (:attr:`mbi_navi`), its ``bf`` data is saved as well.
    """
    # NOTE(review): imported but not used in this method.
    from ocelot.cpbd.io import save_particle_array

    super().postProcess()

    twiss_columns = {key: [] for key in vars(self.tws[0])}
    s_offset = self.startObject.physical.start.z
    for twiss in self.tws:
        for key, value in vars(twiss).items():
            # Measure s from the start of this lattice section
            if key == "s":
                value += s_offset
            twiss_columns[key].append(value)

    element_s = array(self.getSValues(at_entrance=False)) + twiss_columns["s"][0]
    element_z = [z_entry[-1] for z_entry in self.getZValues()]
    twiss_columns["z"] = interp(twiss_columns["s"], element_s, element_z)

    subdir = self.global_parameters["master_subdir"]
    save_ocelot_twiss_hdf(
        self,
        filename=f'{subdir}/{self.objectname}_twiss.oh5',
        twiss=twiss_columns,
    )

    if self.mbi_navi is not None:
        # NOTE(review): numpy.save appends ".npy" to names lacking it, so this
        # lands on disk as "<objectname>_mbi.dat.npy" — confirm readers agree.
        save(f'{subdir}/{self.objectname}_mbi.dat', self.mbi_navi.bf)
[docs]
def navi_setup(self) -> "Navigator":
    """
    Set up the physics processes for Ocelot (i.e. space charge, CSR, wakes etc).

    .. _Navigator: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/navi.py

    The following processes are collected and attached to the navigator in a
    single ``add_physics_processes`` call at the end:

    * LSC over the whole lattice, if :attr:`lsc` is set;
    * 3D space charge over the whole lattice, if
      ``file_block["charge"]["space_charge_mode"]`` is ``"3d"`` (any case);
    * CSR, if ``file_block`` has a ``"csr"`` entry (see :meth:`physproc_csr`);
    * a wake for every cavity / wake element that has a wake definition;
    * a beam transform at every ``twissmatch`` element;
    * an openPMD beam dump at every screen, BPM and aperture, and one at the
      last element of the lattice.

    If ``mbi["set_mbi"]`` is truthy, a microbunching-gain (MBI) process is
    also built. It is stored in :attr:`mbi_navi` and attached to the
    navigator directly, before the processes above. A missing ``set_mbi``
    key is treated as False.

    Returns
    -------
    Navigator
        An Ocelot `Navigator`_ object
    """
    from ocelot.cpbd.navi import Navigator
    from ocelot import Twiss
    from .savebeamopenpmd import SaveBeamOpenPMD
    from .mbi import MBI

    navi_processes = []
    navi_locations_start = []
    navi_locations_end = []

    def add_process(process, start_elem, end_elem):
        # Keep the three parallel lists for add_physics_processes in step
        navi_processes.append(process)
        navi_locations_start.append(start_elem)
        navi_locations_end.append(end_elem)

    navi = Navigator(self.lat_obj, unit_step=self.unit_step)
    sequence = self.lat_obj.sequence

    if self.lsc:
        add_process(self.physproc_lsc(), sequence[0], sequence[-1])

    # Flags forwarded to the MBI calculation below
    space_charge_set = False
    csr_set = False

    if "charge" in self.file_block:
        charge = self.file_block["charge"]
        if (
            "space_charge_mode" in charge
            and charge["space_charge_mode"].lower() == "3d"
        ):
            gridsize = self.grids.getGridSizes(
                len(self.global_parameters["beam"].x) / self.sample_interval
            )
            # An explicit ``sc_grid`` overrides the automatically chosen size
            g1 = self.sc_grid if hasattr(self, "sc_grid") else gridsize
            add_process(self.physproc_sc([g1, g1, g1]), sequence[0], sequence[-1])
            space_charge_set = True

    if "csr" in self.file_block:
        csr, start, end = self.physproc_csr()
        for process, start_elem, end_elem in zip(csr, start, end):
            add_process(process, start_elem, end_elem)
        # Record that CSR is active so the MBI calculation includes it
        csr_set = len(csr) > 0

    if self.mbi.get("set_mbi", False):
        self.mbi_navi = MBI(
            lattice=self.lat_obj,
            lamb_range=list(
                linspace(
                    float(self.mbi["min"]),
                    float(self.mbi["max"]),
                    int(self.mbi["nstep"]),
                )
            ),
            lsc=space_charge_set,
            csr=csr_set,
            slices=self.mbi["slices"],
        )
        # The MBI process works on its own copies of the navigator and lattice
        self.mbi_navi.navi = deepcopy(navi)
        self.mbi_navi.lattice = deepcopy(self.lat_obj)
        # NOTE(review): this unconditionally overrides the lsc=space_charge_set
        # argument passed above — confirm that is intended.
        self.mbi_navi.lsc = True
        navi.add_physics_proc(self.mbi_navi, sequence[0], sequence[-1])

    for name, obj in self.elements.items():
        hardware_type = obj.hardware_type.lower()
        if "cavity" in hardware_type:
            fieldstr = "wakefield_definition"
        elif "wake" in hardware_type:
            fieldstr = "field_definition"
        else:
            fieldstr = None
        if fieldstr is not None:
            definition = getattr(obj.simulation, fieldstr)
            if definition is not None:
                wake, w_ind = self.physproc_wake(
                    name, definition, obj.cavity.n_cells
                )
                # Attached from the element to the next entry in the sequence
                add_process(wake, sequence[w_ind], sequence[w_ind + 1])
        if hardware_type == "twissmatch":
            twsobj = Twiss(
                beta_x=obj.simulation.beta_x,
                beta_y=obj.simulation.beta_y,
                alpha_x=obj.simulation.alpha_x,
                alpha_y=obj.simulation.alpha_y,
                Dx=obj.simulation.eta_x,
                Dy=obj.simulation.eta_y,
                Dxp=obj.simulation.eta_xp,
                Dyp=obj.simulation.eta_yp,
            )
            transform = self.physproc_beamtransform(tws=twsobj)
            match_elem = sequence[self.names.index(name)]
            add_process(transform, match_elem, match_elem)

    subdir = self.global_parameters["master_subdir"]
    for w in self.screens_and_bpms + self.apertures:
        loc = sequence[self.names.index(w.name)]
        add_process(
            SaveBeamOpenPMD(
                filename=f"{subdir}/{w.name}.openpmd.hdf5",
                global_parameters=self.global_parameters,
                zstart=w.physical.start.z,
                ref_idx=self.ref_idx,
            ),
            loc,
            loc,
        )

    # Always dump the beam at the final element of the lattice
    loc = sequence[-1]
    add_process(
        SaveBeamOpenPMD(
            filename=f"{subdir}/{self.names[-1]}.openpmd.hdf5",
            global_parameters=self.global_parameters,
            zstart=self.endObject.physical.end.z,
            ref_idx=self.ref_idx,
        ),
        loc,
        loc,
    )

    navi.add_physics_processes(
        navi_processes, navi_locations_start, navi_locations_end
    )
    return navi
[docs]
def physproc_lsc(self) -> "LSC":
    """
    Get an Ocelot `LSC`_ physics process.

    .. _LSC: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/sc.py

    The process's ``smooth_param`` is copied from :attr:`smooth_param`.

    Returns
    -------
    LSC
        The Ocelot LSC PhysProc
    """
    from ocelot.cpbd.sc import LSC

    process = LSC()
    process.smooth_param = self.smooth_param
    return process
[docs]
def physproc_sc(self, grids: List[int]) -> "SpaceCharge":
    """
    Get an Ocelot `SpaceCharge`_ physics process.

    .. _SpaceCharge: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/sc.py

    The process is created with ``step=1``, and its random meshing flag is
    copied from :attr:`random_mesh`.

    Parameters
    ----------
    grids: List[int]
        The space charge grid number in x, y, z (assigned to ``nmesh_xyz``)

    Returns
    -------
    SpaceCharge
        The Ocelot SpaceCharge PhysProc
    """
    from ocelot.cpbd.sc import SpaceCharge

    space_charge = SpaceCharge(step=1)
    space_charge.nmesh_xyz = grids
    space_charge.random_mesh = self.random_mesh
    return space_charge
[docs]
def physproc_csr(self) -> tuple:
    """
    Get Ocelot `CSR`_ physics processes based on the start and end positions provided in `file_block`.
    If these are not provided, just include CSR for the entire lattice.

    .. _CSR: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/csr.py

    ``file_block["csr"]["start"]`` and ``file_block["csr"]["end"]`` may each
    be a single element name or a list of names. The i-th start is paired
    with the i-th end. A bare ``csr`` entry (``None``) is treated like one
    without start/end keys. Every CSR process is configured from
    :attr:`nbin_csr`, :attr:`mbin_csr` and :attr:`sigmamin_csr`.

    Returns
    -------
    tuple
        ``(csr_processes, start_elements, end_elements)``: three lists of
        equal length.

    Raises
    ------
    ValueError
        If the number of start names differs from the number of end names,
        or a name is not present in :attr:`names`.
    """
    from ocelot.cpbd.csr import CSR

    def new_csr():
        # A freshly configured CSR process (one per range)
        csr = CSR()
        csr.n_bin = self.nbin_csr
        csr.m_bin = self.mbin_csr
        csr.sigma_min = self.sigmamin_csr
        return csr

    sequence = self.lat_obj.sequence
    csr_block = self.file_block["csr"] or {}

    if "start" in csr_block and "end" in csr_block:
        starts = csr_block["start"]
        ends = csr_block["end"]
        starts = [starts] if isinstance(starts, str) else list(starts)
        ends = [ends] if isinstance(ends, str) else list(ends)
        if len(starts) != len(ends):
            raise ValueError(
                f"CSR start/end lists differ in length: {len(starts)} start(s) "
                f"vs {len(ends)} end(s)"
            )
        csrlist, stlist, enlist = [], [], []
        for start_name, end_name in zip(starts, ends):
            stlist.append(sequence[self.names.index(start_name)])
            enlist.append(sequence[self.names.index(end_name)])
            csrlist.append(new_csr())
        return csrlist, stlist, enlist

    # No explicit range: a single CSR process spanning the whole lattice.
    # It must be included in the returned list, otherwise no CSR is applied.
    return [new_csr()], [sequence[0]], [sequence[-1]]
[docs]
def physproc_wake(
    self,
    name: str,
    loc: field | str,
    ncell: int,
) -> tuple:
    """
    Get an Ocelot `Wake`_ physics process based on the wakefield provided.

    .. _Wake: https://github.com/ocelot-collab/ocelot/blob/master/ocelot/cpbd/wake3D.py

    The wake table is read from ``<master_subdir>/<basename of loc>``, with
    ``.hdf5`` replaced by ``.astra`` in the file name. The process uses
    ``step=100`` and takes its sampling and filter order from
    :attr:`wake_sampling` and :attr:`wake_filter`.

    Parameters
    ----------
    name: str
        Name of lattice object associated with the wake
    loc: :class:`~simba.Modules.Fields.field` or str
        Wakefield definition. A `field` is first written out with
        ``write_field_file(code="astra")`` and the returned path is used;
        a str is used as the path directly.
    ncell: int
        Number of cells; the wake is scaled by ``ncell * wake_factor``

    Returns
    -------
    tuple
        The Wake PhysProc, and the index of ``name`` in :attr:`names`
    """
    from ocelot.cpbd.wake3D import Wake, WakeTable

    source_path = loc.write_field_file(code="astra") if isinstance(loc, field) else loc
    table_file = os.path.basename(source_path).replace('.hdf5', '.astra')
    table_path = self.global_parameters["master_subdir"] + '/' + table_file

    wake = Wake(
        step=100,
        w_sampling=self.wake_sampling,
        filter_order=self.wake_filter,
    )
    wake.factor = ncell * self.wake_factor
    wake.wake_table = WakeTable(table_path)
    return wake, self.names.index(name)