# Source code for simba.Codes.ASTRA.ASTRA

"""
SIMBA ASTRA Module

Various objects and functions to handle ASTRA lattices and commands. See `ASTRA manual`_ for more details.

    .. _ASTRA manual: https://www.desy.de/~mpyflo/Astra_manual/Astra-Manual_V3.2.pdf

Classes:
    - :class:`~simba.Codes.ASTRA.ASTRA.astraLattice`: The ASTRA lattice object, used for
    converting the :class:`~simba.Framework_objects.frameworkObject` s defined in the
    :class:`~simba.Framework_objects.frameworkLattice` into a string representation of
    the lattice suitable for an ASTRA input file.

    - :class:`~simba.Codes.ASTRA.ASTRA.astra_header`: Class for defining the &HEADER portion
    of the ASTRA input file.

    - :class:`~simba.Codes.ASTRA.ASTRA.astra_newrun`: Class for defining the &NEWRUN portion
    of the ASTRA input file.

    - :class:`~simba.Codes.ASTRA.ASTRA.astra_charge`: Class for defining the &CHARGE portion
    of the ASTRA input file.

    - :class:`~simba.Codes.ASTRA.ASTRA.astra_output`: Class for defining the &OUTPUT portion
    of the ASTRA input file.

    - :class:`~simba.Codes.ASTRA.ASTRA.astra_errors`: Class for defining the &ERRORS portion
    of the ASTRA input file.
"""

import os
from warnings import warn
import numpy as np
import lox
from lox.worker.thread import ScatterGatherDescriptor
from typing import ClassVar, Dict, List, Any, Tuple
from pydantic import Field, field_validator, ConfigDict

from ...Framework_objects import frameworkLattice, global_error
from ...FrameworkHelperFunctions import expand_substitution, saveFile
from ...Modules import Beams as rbf
from laura.models.diagnostic import DiagnosticElement
from laura.models.element import PhysicalBaseElement
from laura.models.physical import PhysicalElement
from laura.translator.converters.codes.astra import (
    astra_newrun,
    astra_charge,
    astra_output,
    astra_errors,
)

from ...Modules.units import UnitValue

# Lookup table keyed by element-group / header-object name.
# Each entry carries the section name under "header"; element groups also
# carry a "bool" entry.
# NOTE(review): "header" looks like the ASTRA namelist name and "bool" the
# ASTRA switch enabling that element type -- confirm against the ASTRA manual.
# This table is not referenced in the visible part of this file.
section_header_text_ASTRA = {
    "cavities": {"header": "CAVITY", "bool": "LEField"},
    "wakefields": {"header": "WAKE", "bool": "LWAKE"},
    "solenoids": {"header": "SOLENOID", "bool": "LBField"},
    "quadrupoles": {"header": "QUADRUPOLE", "bool": "LQuad"},
    "dipoles": {"header": "DIPOLE", "bool": "LDipole"},
    "astra_newrun": {"header": "NEWRUN"},
    "astra_output": {"header": "OUTPUT"},
    "astra_charge": {"header": "CHARGE"},
    "global_error": {"header": "ERROR"},
    "apertures": {"header": "APERTURE", "bool": "LApert"},
}


class astraLattice(frameworkLattice):
    """
    Class for defining the ASTRA lattice object, used for
    converting the :class:`~simba.Framework_objects.frameworkObject`s defined in the
    :class:`~simba.Framework_objects.frameworkLattice` into a string representation of
    the lattice suitable for an ASTRA input file.
    """

    model_config = ConfigDict(validate_assignment=True)

    screen_threaded_function: ClassVar[ScatterGatherDescriptor] = (
        ScatterGatherDescriptor
    )
    """Function for converting all screen outputs from ASTRA into the SIMBA generic
    :class:`~simba.Modules.Beams.beam` object and writing files"""

    code: str = "astra"
    """String indicating the lattice object type"""

    allow_negative_drifts: bool = True
    """Flag to indicate whether negative drifts are allowed"""

    _bunch_charge: float | None = None
    """Bunch charge"""

    _toffset: float | None = None
    """Time offset of reference particle"""

    _space_charge_mode: str | None = None

    headers: Dict = {}
    """Headers to be included in the ASTRA lattice file"""

    starting_offset: list[float] = [0.0, 0.0, 0.0]
    """Initial offset of first element"""

    starting_rotation: list[float] = [0.0, 0.0, 0.0]
    """Initial rotation of first element"""

    zstop: float | None = None
    """End z position of lattice"""

    astra_headers: Dict[str, Any] = Field(default_factory=dict)
    """Headers for ASTRA input file"""

    ref_s: float | None = None
    """Reference s position"""

    def model_post_init(self, __context: Any) -> None:
        """
        Build the ASTRA header objects (newrun, output, charge, global errors)
        from the lattice ``file_block`` and ``globalSettings`` and store them in
        ``self.section.astra_headers`` and ``self.astra_headers``.

        Parameters
        ----------
        __context: Any
            Context object forwarded unchanged to the parent ``model_post_init``.
        """
        super().model_post_init(__context)

        # NOTE(security): eval() is applied to values taken from the lattice
        # definition (file_block). This is only safe while those files are trusted.
        self.starting_offset = (
            eval(expand_substitution(self, self.file_block["starting_offset"]))
            if "starting_offset" in self.file_block
            else [0, 0, 0]
        )

        # Default rotation is derived from the global rotation of the start object;
        # an explicit "starting_rotation" entry in the file block overrides it.
        default_rotation = [
            0.0,
            0.0,
            float(-1 * self.startObject.physical.global_rotation.theta),
        ]
        self.starting_rotation = (
            eval(expand_substitution(self, str(self.file_block["starting_rotation"])))
            if "starting_rotation" in self.file_block
            else default_rotation
        )

        # Create a "newrun" block
        if "input" not in self.file_block:
            self.file_block["input"] = {}
        if "ASTRAsettings" not in self.globalSettings:
            self.globalSettings["ASTRAsettings"] = {}
        newrun_settings = (
            self.file_block["input"] | self.globalSettings["ASTRAsettings"]
        )
        starting_offset = [
            a + b
            for a, b in zip(self.startObject.physical.start, self.starting_offset)
        ]
        self.section.astra_headers["newrun"] = astra_newrun(
            starting_offset=starting_offset,
            starting_rotation=self.starting_rotation,
            global_parameters=self.global_parameters,
            input_particle_definition=self.startObject.name,
            **newrun_settings,
        )
        newrun = self.section.astra_headers["newrun"]

        # If the initial distribution is derived from a generator file, we should use that
        if "particle_definition" in self.file_block["input"]:
            particle_definition = self.file_block["input"]["particle_definition"]
            if particle_definition == "initial_distribution":
                newrun.input_particle_definition = "laser.astra"
                newrun.output_particle_definition = "laser.astra"
            else:
                newrun.input_particle_definition = particle_definition
                newrun.output_particle_definition = self.objectname + ".astra"
        else:
            newrun.input_particle_definition = self.start + ".astra"
            newrun.output_particle_definition = self.objectname + ".astra"

        # Create an "output" block
        if "output" not in self.file_block:
            self.file_block["output"] = {}
        output_settings = (
            self.file_block["output"] | self.globalSettings["ASTRAsettings"]
        )
        zstart = self.startObject.physical.start.z
        self.zstop = self.endObject.physical.end.z
        screens = [
            e
            for e in self.section.elements.elements.values()
            if e.hardware_class == "Diagnostic"
        ]
        # zstart is always taken from the start object, never from the settings
        output_settings.pop("zstart", None)
        self.section.astra_headers["output"] = astra_output(
            starting_offset=self.starting_offset,
            starting_rotation=self.starting_rotation,
            global_parameters=self.global_parameters,
            zstart=zstart,
            zstop=self.zstop,
            zemit=int((self.zstop - zstart) / 0.01),
            screens=screens,
            **output_settings,
        )

        # Create a "charge" block
        if "charge" not in self.file_block:
            self.file_block["charge"] = {}
        if "charge" not in self.globalSettings:
            self.globalSettings["charge"] = {}
        space_charge_dict = self.file_block["charge"] | self.globalSettings["charge"]
        charge_settings = space_charge_dict | self.globalSettings["ASTRAsettings"]
        self.section.astra_headers["charge"] = astra_charge(
            global_parameters=self.global_parameters,
            **charge_settings,
        )

        # Create an "error" block (always created; empty settings by default)
        if "global_errors" not in self.file_block:
            self.file_block["global_errors"] = {}
        if "global_errors" not in self.globalSettings:
            self.globalSettings["global_errors"] = {}
        globalerror = global_error(
            objectname=self.objectname + "_global_error",
            objecttype="global_error",
            global_parameters=self.global_parameters,
        )
        error_settings = (
            self.file_block["global_errors"] | self.globalSettings["global_errors"]
        )
        self.section.astra_headers["global_errors"] = astra_errors(
            element=globalerror,
            global_parameters=self.global_parameters,
            **error_settings,
        )

        self.astra_headers = self.section.astra_headers

    @property
    def space_charge_mode(self) -> str:
        """
        The space charge type for ASTRA, i.e. "2D", "3D".

        Returns
        -------
        str
            The space charge type for ASTRA
        """
        return str(self.astra_headers["charge"].space_charge_mode)

    @space_charge_mode.setter
    def space_charge_mode(self, mode: str) -> None:
        """
        Sets the space charge mode on the "charge" header object.

        Parameters
        ----------
        mode: str
            Space charge mode
        """
        self.astra_headers["charge"].space_charge_mode = str(mode)

    @property
    def sample_interval(self) -> int:
        """
        Factor by which to reduce the number of particles in the simulation, i.e. every 10th particle.

        Returns
        -------
        int
            The sampling interval `n_red` in ASTRA
        """
        return self._sample_interval

    @sample_interval.setter
    def sample_interval(self, interval: int) -> None:
        """
        Sets the factor by which to reduce the number of particles in the simulation in the &NEWRUN header,
        and scales the number of space charge bins in the &CHARGE header accordingly; see
        :func:`~simba.Codes.ASTRA.ASTRA.astra_newrun.framework_dict`,
        :func:`~simba.Codes.ASTRA.ASTRA.astra_charge.grid_size`.

        Parameters
        ----------
        interval: int
            Sampling interval
        """
        self._sample_interval = interval
        self.astra_headers["newrun"].sample_interval = interval
        self.astra_headers["charge"].sample_interval = interval

    @property
    def bunch_charge(self) -> float:
        """
        Bunch charge in coulombs

        Returns
        -------
        float
            Bunch charge
        """
        return self._bunch_charge

    @bunch_charge.setter
    def bunch_charge(self, charge: float) -> None:
        """
        Sets the bunch charge for this object and also in
        :class:`~simba.Codes.ASTRA.ASTRA.astra_newrun`.

        Parameters
        ----------
        charge: float
            Bunch charge in coulombs
        """
        self._bunch_charge = charge
        self.astra_headers["newrun"].bunch_charge = charge

    @property
    def toffset(self) -> float:
        """
        Get the time offset for the reference particle.

        Returns
        -------
        float
            The time offset in seconds
        """
        return self._toffset

    @toffset.setter
    def toffset(self, toffset: float) -> None:
        """
        Set the time offset for this object and the
        :class:`~simba.Codes.ASTRA.ASTRA.astra_newrun` object.

        Parameters
        ----------
        toffset: float
            The time offset in seconds
        """
        self._toffset = toffset
        # The newrun header receives the value scaled by 1e9
        # (seconds -> nanoseconds, given the documented input unit).
        self.astra_headers["newrun"].toffset = 1e9 * toffset

    def write(self) -> None:
        """
        Writes the ASTRA input file produced by ``self.section.to_astra()``
        to <master_subdir>/<self.objectname>.in and records it in ``self.files``.
        """
        code_file = (
            self.global_parameters["master_subdir"] + "/" + self.objectname + ".in"
        )
        # Push any header changes made on this object back to the section
        # before it is serialised.
        self.section.astra_headers = self.astra_headers
        saveFile(code_file, self.section.to_astra())
        self.files.append(code_file)

    def preProcess(self) -> None:
        """
        Convert the beam file from the previous lattice section into ASTRA format
        and set the number of particles based on the input distribution, see
        :func:`~simba.Codes.ASTRA.ASTRA.astraLattice.hdf5_to_astra`.
        """
        super().preProcess()
        prefix = self.get_prefix()
        newrun = self.astra_headers["newrun"]
        # Called for its effect on the framework state; the return value is not used here.
        self.read_input_file(
            prefix,
            newrun.input_particle_definition.replace(".astra", ""),
        )
        beam = self.global_parameters["beam"]
        self.ref_s = beam.s if beam.s is not None else 0
        newrun.input_particle_definition = self.hdf5_to_astra()
        self.astra_headers["charge"].npart = len(beam.x)

    @lox.thread
    def screen_threaded_function(
        self,
        objectname: str,
        scr: DiagnosticElement,
        cathode: bool,
        mult: int,
        sval: float = 0.0,
    ) -> None:
        """
        Convert output from ASTRA screen to HDF5 format

        Parameters
        ----------
        objectname: str
            Name of screen object
        scr: :class:`~laura.models.diagnostic.DiagnosticElement`
            Screen object
        cathode: bool
            True if beam was emitted from a cathode
        mult: int
            Multiplication factor for ASTRA-type filenames
        sval: float
            S-position of beam
        """
        # sval must be passed by keyword: the fifth positional parameter of
        # astra_to_hdf5 is `final`, not `sval`.
        return self.astra_to_hdf5(objectname, scr, cathode, mult, sval=sval)

    def get_screen_scaling(self) -> int:
        """
        Determine the screen scaling factor for screens and BPMs

        Returns
        -------
        int
            The first of 100, 1000, 10 for which an ASTRA output file is found
            for every screen/BPM; 100 if none of them matches.
        """
        master_run_no = (
            self.global_parameters["run_no"]
            if "run_no" in self.global_parameters
            else 1
        )
        for mult in [100, 1000, 10]:
            foundscreens = [
                self.find_ASTRA_filename(self.objectname, e, master_run_no, mult)
                for e in self.screens_and_bpms
            ]
            if all(foundscreens):
                return mult
        return 100

    def postProcess(self) -> None:
        """
        Convert the beam file(s) from the ASTRA output into HDF5 format, see
        :func:`~simba.Codes.ASTRA.ASTRA.astraLattice.astra_to_hdf5`.
        """
        super().postProcess()
        cathode = (
            self.astra_headers["newrun"].input_particle_definition
            == "initial_distribution"
        )
        mult = self.get_screen_scaling()
        # ref_s is only set by preProcess; fall back to 0 if it has not run.
        ref_s = self.ref_s if self.ref_s is not None else 0.0
        svals = np.array(self.getSValues(at_entrance=False)) + ref_s
        zvals = [a[-1] for a in self.getZValues()]
        for e in self.screens_and_bpms:
            # NOTE(review): this reads e.middle.z while the other methods read
            # scr.physical.middle.z -- confirm both are valid on these elements.
            sval = np.interp(e.middle.z, zvals, svals)
            self.screen_threaded_function.scatter(
                scr=e,
                objectname=self.objectname,
                cathode=cathode,
                mult=mult,
                sval=sval,
            )
        self.screen_threaded_function.gather()

        # Stand-in element located at the end of the lattice, used to convert
        # the final ASTRA distribution.
        endelem = PhysicalBaseElement(
            name=self.end,
            hardware_class="",
            hardware_type="",
            machine_area="",
            physical=PhysicalElement(middle=[0, 0, self.zstop]),
        )
        self.astra_to_hdf5(
            lattice=self.objectname,
            scr=endelem,
            cathode=cathode,
            mult=mult,
            final=True,
        )

    def astra_to_hdf5(
        self,
        lattice: str,
        scr: DiagnosticElement | PhysicalBaseElement,
        cathode: bool = False,
        mult: int = 100,
        final: bool = False,
        sval: float = 0.0,
    ) -> None:
        """
        Convert the ASTRA beam file for a screen to HDF5 format and write the beam file.

        If no matching ASTRA file is found a warning is issued and nothing is written.

        Parameters
        ----------
        lattice: str
            Lattice name
        scr: laura.models.diagnostic.DiagnosticElement
            LAURA DiagnosticElement
        cathode: bool
            True if beam was emitted from a cathode (not used in this method)
        mult: int
            Multiplication factor for ASTRA-type filenames
        final: bool
            If True, the converted beam is stored in ``global_parameters["beam"]``
        sval: float
            S-position of beam
        """
        master_run_no = (
            self.global_parameters["run_no"]
            if "run_no" in self.global_parameters
            else 1
        )
        astrabeamfilename = self.find_ASTRA_filename(lattice, scr, master_run_no, mult)
        if astrabeamfilename is None:
            warn(f"Screen Error: {lattice}, {scr.physical.middle.z}, {astrabeamfilename}")
            return

        master_subdir = self.global_parameters["master_subdir"]
        astra_path = os.path.join(master_subdir, astrabeamfilename).strip('"')

        beam = rbf.beam()
        rbf.astra.read_astra_beam_file(beam, astra_path, normaliseZ=False)
        # Undo the starting rotation and offset that were applied for the ASTRA run
        rbf.hdf5.rotate_beamXZ(
            beam,
            -1 * self.starting_rotation[2],
            preOffset=[0, 0, 0],
            postOffset=-1 * np.array(self.starting_offset),
        )
        beam.s = UnitValue(sval, units="m")
        HDF5filename = scr.name + ".openpmd.hdf5"
        rbf.openpmd.write_openpmd_beam_file(
            beam,
            master_subdir + "/" + HDF5filename,
        )
        if self.global_parameters["delete_tracking_files"]:
            os.remove(astra_path)
        # Only reachable when a beam was actually read, so `beam` is always bound.
        if final:
            self.global_parameters["beam"] = beam

    def find_ASTRA_filename(
        self,
        lattice: str,
        scr: DiagnosticElement | PhysicalBaseElement,
        master_run_no: int,
        mult: int,
    ) -> str | None:
        """
        Determine the ASTRA filename for the screen object.

        Candidate names have the form ``<lattice>.<position>.<run>``, where the
        position is a z value multiplied by `mult`, rounded, and zero-padded to
        four digits. Positions are tried for the screen middle and for the lattice
        end (``self.zstop``), both relative to the lattice start and absolute, each
        at offsets of 0, -1 mm and +1 mm.

        Parameters
        ----------
        lattice: str
            The name of the lattice
        scr: laura.models.diagnostic.DiagnosticElement
            LAURA DiagnosticElement
        master_run_no: int
            The run number
        mult: int
            Multiplication factor for ASTRA-type output

        Returns
        -------
        str or None
            The ASTRA filename for the screen object, or None if the file does not exist.
        """
        master_subdir = self.global_parameters["master_subdir"]
        zstart = self.startObject.physical.start.z
        run_suffix = str(master_run_no).zfill(3)

        def candidate(z: float) -> str:
            return lattice + "." + str(int(round(z * mult))).zfill(4) + "." + run_suffix

        for i in [0, -0.001, 0.001]:
            screen_z = scr.physical.middle.z + i
            end_z = self.zstop + i
            # Order matters: the first existing candidate wins.
            candidates = [
                candidate(screen_z - zstart),
                candidate(end_z),
                candidate(end_z - zstart),
                candidate(screen_z),
            ]
            for f in candidates:
                if os.path.isfile(os.path.join(master_subdir, f)):
                    return f
        return None

    def hdf5_to_astra(self) -> str:
        """
        Convert beam input file to ASTRA format and write to `master_subdir`.

        Returns
        -------
        str
            Name of ASTRA beam file
        """
        astrabeamfilename = self.astra_headers["newrun"].output_particle_definition
        rbf.astra.write_astra_beam_file(
            self.global_parameters["beam"],
            self.global_parameters["master_subdir"] + "/" + astrabeamfilename,
            normaliseZ=False,
        )
        return astrabeamfilename