Source code for simba.Codes.Cheetah.Cheetah

"""
SIMBA Cheetah Module

Various objects and functions to handle Cheetah lattices and commands. See `Cheetah github`_ for more details.

    .. _Cheetah github: https://github.com/desy-ml/cheetah

Classes:
    - :class:`~simba.Codes.Cheetah.Cheetah.cheetahLattice`: The Cheetah lattice object, used for
    converting the :class:`~simba.Framework_objects.frameworkObject` s defined in the
    :class:`~simba.Framework_objects.frameworkLattice` into a Cheetah lattice object,
    and for tracking through it.

"""
from torch import Tensor

from ...Framework_objects import frameworkLattice
from ...Modules import Beams as rbf

import os
from yaml import safe_load
from copy import deepcopy
from typing import Dict, Any, ClassVar
import h5py
import lox
from lox.worker.thread import ScatterGatherDescriptor
from laura.models.diagnostic import DiagnosticElement


with open(
    os.path.dirname(os.path.abspath(__file__)) + "/cheetah_defaults.yaml",
    "r",
) as infile:
    cheetahglobal = safe_load(infile)

twiss_keys = (
    "beta_x",
    "beta_y",
    "alpha_x",
    "alpha_y",
    "s",
    "energy",
    "emittance_x",
    "emittance_y",
    "sigma_x",
    "sigma_y",
    "sigma_px",
    "sigma_py",
    "mu_x",
    "mu_y",
    "sigma_tau",
    "sigma_p",
)

[docs] class cheetahLattice(frameworkLattice): """ Class for defining the Cheetah lattice object, used for converting the :class:`~simba.Framework_objects.frameworkObject`s defined in the :class:`~simba.Framework_objects.frameworkLattice` into a Cheetah lattice object, and for tracking through it. """ screen_threaded_function: ClassVar[ScatterGatherDescriptor] = ( ScatterGatherDescriptor ) """Function for converting all screen outputs from ELEGANT into the SIMBA generic :class:`~simba.Modules.Beams.beam` object and writing files""" code: str = "cheetah" """String indicating the lattice object type""" trackBeam: bool = True """Flag to indicate whether to track the beam""" segment: Any | None = None """ Lattice elements arranged into a Cheetah `Segment`_ .. _Segment: https://github.com/desy-ml/cheetah/blob/master/cheetah/accelerator/segment.py """ pin: Any | None = None """Initial particle distribution as a Cheetah `ParticleArray`_ .. _ParticleBeam: https://github.com/desy-ml/cheetah/blob/master/cheetah/particles/particle_beam.py""" pout: Any | None = None """Final particle distribution as a Cheetah `ParticleArray`_""" tws: tuple[Tensor, ...] | Tensor | None = None """Tensor or tuple of Tensors containing Twiss parameters""" cheetahglobal: Dict = {} """Global settings for Cheetah, read in from `cheetahLattice.settings["global"]["Cheetahsettings"]` and `cheetah_defaults.yaml`""" particle_definition: str = None """Initial particle distribution as a string""" ref_s: float = None """Reference s position""" ref_idx: int = None """Reference particle index""" def model_post_init(self, __context): super().model_post_init(__context) self.cheetahglobal = deepcopy(cheetahglobal) if "CHEETAHsettings" in list(self.settings["global"].keys()): for k, v in self.settings["global"]["CHEETAHsettings"].items(): if isinstance(v, Dict): for k1, v1 in v.items(): self.cheetahglobal[k].update({k1: v1}) else: self.cheetahglobal.update({k: v}) if ( "input" in self.file_block and "particle_definition" in self.file_block["input"] ): if ( self.file_block["input"]["particle_definition"] == "initial_distribution" ): self.particle_definition = "laser" else: self.particle_definition = self.file_block["input"][ "particle_definition" ] else: self.particle_definition = self.start
[docs] def writeElements(self) -> bool: """ Create Cheetah objects for all the elements in the lattice and set the :attr:`~simba.Codes.Cheetah.Cheetah.cheetahLattice.segment`. Returns ------- bool True if successful """ self.segment = self.section.to_cheetah(save=True) return True
[docs] def write(self) -> None: """ Create the lattice object via :func:`~simba.Codes.Cheetah.Cheetah.cheetahLattice.writeElements` and save it as a JSON file to `master_subdir`. """ success = self.writeElements() if success: self.segment.to_lattice_json( filepath=f'{self.global_parameters["master_subdir"]}/{self.objectname}.json' )
[docs] def preProcess(self) -> None: """ Get the initial particle distribution defined in `file_block['input']['prefix']` if it exists. """ super().preProcess() prefix = self.get_prefix() prefix = prefix if self.trackBeam else prefix + self.particle_definition self.read_input_file(prefix, self.particle_definition) self.ref_s = self.global_parameters["beam"].s self.ref_idx = self.global_parameters["beam"].reference_particle_index self.hdf5_to_openpmd()
[docs] def hdf5_to_openpmd(self, prefix="", write=True) -> None: """ Convert the initial HDF5 particle distribution to OpenPMD format and set :attr:`~simba.Codes.Cheetah.Cheetah.cheetahLattice.pin` accordingly. Parameters ---------- prefix: str Prefix for particle file write: bool Flag to indicate whether to save the file """ cheetahbeamfilename = self.particle_definition + ".openpmd.hdf5" self.global_parameters["beam"].beam.rematchXPlane(**self.initial_twiss["horizontal"]) self.global_parameters["beam"].beam.rematchYPlane(**self.initial_twiss["vertical"]) self.pin = rbf.beam.write_cheetah_beam_file( self.global_parameters["beam"], cheetahbeamfilename, write=write )
[docs] def run(self) -> None: """ Run the code, and set :attr:`~tws` and :attr:`~pout` """ # navi = self.navi_setup() pin = deepcopy(self.pin) # if self.sample_interval > 1: # pin = pin.thin_out(nth=self.sample_interval) self.pout = self.segment.track(pin) if self.cheetahglobal["save_twiss"]: self.tws = self.segment.get_beam_attrs_along_segment(twiss_keys, pin)
# print("Twiss parameters:", self.tws) @lox.thread(40) def screen_threaded_function(self, scr: DiagnosticElement, outname: str, name: str) -> None: """ Convert output from Cheetah ParticleBeam to HDF5 format Parameters ---------- scr: LAURA DiagnosticElement Screen object outname: str Name of Cheetah beam file name: str Name of element """ from ...Modules.Beams import cheetah as rbf_cheetah beam = rbf.beam() s = 0 try: s = self.elementObjects[name].physical.middle.z except KeyError: s = self.elementObjects[name.replace('_', "-")].physical.middle.z # scr.tau -= self.startObject.physical.middle.z rbf_cheetah.interpret_cheetah_ParticleBeam( beam, scr, zstart=self.startObject.physical.start.z, s=scr.s.numpy(), ref_index=self.ref_idx, ) rbf.openpmd.write_openpmd_beam_file(beam, outname) if name == self.end: self.global_parameters["beam"] = beam
[docs] def postProcess(self) -> None: """ Convert the outputs from Cheetah to HDF5 format and save them to `master_subdir`. """ from cheetah.accelerator import Screen screens = {} for element in self.segment.elements: if isinstance(element, Screen): screens.update({element.name: element.get_read_beam()}) if not isinstance(self.segment.elements[-1], Screen): screens.update({self.end: self.pout}) i = 0 for name, scr in screens.items(): outname = f'{self.global_parameters["master_subdir"]}/{name.replace("_", "-")}.openpmd.hdf5' self.screen_threaded_function.scatter(scr, outname, name) i += 1 self.screen_threaded_function.gather() if self.cheetahglobal["save_twiss"] and self.tws is not None: twsname = f'{self.global_parameters["master_subdir"]}/{self.objectname}_twiss.cheetah.hdf5' with h5py.File(twsname, "w") as f: twsgrp = f.create_group("Twiss") for key, val in zip(twiss_keys, self.tws): twsgrp.create_dataset(key, data=val.numpy())