import os
import re
from shutil import copyfile
import numpy as np
from .Modules.Fields import field
from pydantic import BaseModel
from deepdiff import DeepDiff
from numbers import Number
from laura.models.element import Element
[docs]
def readFile(fname):
    """Return the lines of the text file *fname* as a list.

    Each element keeps its trailing line ending, exactly as produced by
    iterating the opened file.
    """
    with open(fname) as handle:
        return list(handle)
[docs]
def saveFile(filename, lines=None, mode="w"):
    """Write every string in *lines* to *filename*.

    Args:
        filename: path of the file to write.
        lines: iterable of strings, written back-to-back with no separator
            added (include line endings yourself). ``None`` writes nothing,
            which truncates the file in the default ``"w"`` mode.
        mode: mode passed to ``open`` (e.g. ``"w"`` to overwrite, ``"a"`` to
            append).
    """
    # ``with`` guarantees the file is closed even if a write raises; the
    # previous open()/close() pair leaked the handle on error.
    with open(filename, mode) as stream:
        stream.writelines(lines if lines is not None else ())
[docs]
def findSetting(setting, value, dictionary={}):
    """Collect the entries of ``dictionary`` whose ``setting`` equals ``value``.

    Only values of ``dictionary`` that are themselves dicts are inspected;
    all other values are skipped.

    Returns:
        A list of ``[key, entry]`` pairs (two-item lists), in iteration order.
    """
    return [
        [name, entry]
        for name, entry in dictionary.items()
        if isinstance(entry, dict)
        and setting in entry.keys()
        and value == entry[setting]
    ]
[docs]
def findSettingValue(setting, dictionary=None):
    """Return the values of ``setting`` found in the sub-dicts of ``dictionary``.

    Every value of ``dictionary`` that is a dict containing the key
    ``setting`` contributes ``entry[setting]``, in iteration order; all other
    values are skipped.

    Args:
        setting: key to look up inside each sub-dict.
        dictionary: mapping whose values may be dicts; ``None`` means empty.

    Returns:
        List of the found values (empty if none).
    """
    if dictionary is None:
        dictionary = {}
    # The previous implementation indexed findSetting's ``[key, entry]``
    # pairs with ``setting`` (a TypeError whenever anything matched) and only
    # looked at entries whose value was "".
    return [
        entry[setting]
        for entry in dictionary.values()
        if isinstance(entry, dict) and setting in entry
    ]
[docs]
def lineReplaceFunction(line, findString, replaceString, i=None):
    """Replace the ``$findString$`` placeholder in *line*.

    Args:
        line: text line to process.
        findString: placeholder name, without the surrounding ``$`` signs.
        replaceString: replacement value (converted with ``str``), or, when
            ``i`` is given, a sequence whose ``i``-th element is used.
        i: optional index into ``replaceString``. When given and the line
            contains the placeholder, the module-level counter
            ``lineIterator`` is incremented (it must already exist, e.g. set
            to 0 by ``replaceString``).

    Returns:
        The line with every ``$findString$`` replaced, or the line unchanged
        if it contains no placeholder.
    """
    global lineIterator
    token = "$" + findString + "$"
    # Only a real ``$...$`` placeholder counts as a match. Previously a bare
    # occurrence (e.g. "x" inside "max") advanced the shared counter without
    # replacing anything, misaligning list-based substitutions.
    if token not in line:
        return line
    if i is not None:
        lineIterator += 1
        return line.replace(token, str(replaceString[i]))
    return line.replace(token, str(replaceString))
[docs]
def replaceString(lines=[], findString=None, replaceString=None):
    """Run ``lineReplaceFunction`` over every line in *lines*.

    When ``replaceString`` is a list, the module-level ``lineIterator``
    counter is reset to 0 and its current value is passed as the index for
    each line. ``lineReplaceFunction`` advances it on each match, so
    successive matching lines receive successive list entries. For any other
    ``replaceString``, the same value is substituted on every line.

    Returns:
        A new list containing the processed lines.
    """
    global lineIterator
    if not isinstance(replaceString, list):
        return [lineReplaceFunction(line, findString, replaceString) for line in lines]
    lineIterator = 0
    replaced = []
    for line in lines:
        # Read the global afresh for every line: the callee increments it.
        replaced.append(
            lineReplaceFunction(line, findString, replaceString, lineIterator)
        )
    return replaced
[docs]
def chop(expr, delta=1e-8):
    """Replace numbers whose magnitude is at most ``delta`` with 0, recursively.

    - Real numbers (Python or NumPy scalars; anything registered as
      ``numbers.Number``) in ``[-delta, delta]`` become ``0``; others are
      returned unchanged.
    - Complex numbers have their real and imaginary parts chopped
      separately; the result is ``0`` if both vanish, else a ``complex``.
    - Any other ``expr`` is iterated and a list of chopped elements returned.

    Args:
        expr: number or (nested) iterable of numbers.
        delta: non-negative threshold.
    """
    if isinstance(expr, (complex, np.complexfloating)):
        # Complex values cannot be ordered, so the original range test raised
        # TypeError; chop each component instead.
        real = chop(expr.real, delta)
        imag = chop(expr.imag, delta)
        return 0 if real == 0 and imag == 0 else complex(real, imag)
    # ``Number`` (rather than int/float) also accepts NumPy scalars such as
    # np.float32, which previously fell through to iteration and failed.
    if isinstance(expr, Number):
        return 0 if -delta <= expr <= delta else expr
    return [chop(x, delta) for x in expr]
[docs]
def chunks(li, n):
    """Yield successive slices of ``li`` containing at most ``n`` items each.

    The last slice is shorter when ``len(li)`` is not a multiple of ``n``.
    Each slice has whatever type slicing ``li`` produces (list -> list,
    str -> str).
    """
    yield from (li[start : start + n] for start in range(0, len(li), n))
[docs]
def dot(a, b) -> float:
    """Return the dot product of the first three components of *a* and *b*.

    Components beyond index 2 are ignored; both arguments must support
    indexing at 0, 1 and 2.
    """
    ax, ay, az = a[0], a[1], a[2]
    bx, by, bz = b[0], b[1], b[2]
    return ax * bx + ay * by + az * bz
[docs]
def sortByPositionFunction(element):
    """Sort key for ``(name, settings)``-style element pairs.

    Returns ``float(element[1]["position_start"][2])``, i.e. the component
    at index 2 of the start position (presumably the longitudinal/z
    coordinate - TODO confirm).
    """
    settings = element[1]
    start_component = settings["position_start"][2]
    return float(start_component)
[docs]
def rotationMatrix(theta):
    """Return the 3x3 rotation by *theta* about the second (index 1) axis.

    The result is an ``np.matrix`` laid out as
    ``[[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]``, which is the transpose of
    ``_rotation_matrix(theta)`` in this module.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rows = [
        [cos_t, 0, -sin_t],
        [0, 1, 0],
        [sin_t, 0, cos_t],
    ]
    return np.matrix(rows)
[docs]
def getParameter(dicts, param, default=0):
    """Look up ``param`` case-insensitively in one or several dicts.

    Args:
        dicts: a dict, or a list/tuple of dicts in increasing priority (a
            later dict that has the key overrides earlier ones). Non-dict
            items in the sequence are ignored.
        param: key to find (must be a ``str``); compared after ``.lower()``
            against the lower-cased keys (which must therefore be strings).
        default: value returned when the key is not found, or when ``dicts``
            is neither a dict nor a list/tuple.

    Returns:
        The value found for ``param``, else ``default``.
    """
    wanted = param.lower()

    def _lookup(mapping, fallback):
        # Lower-case the keys so matching ignores case.
        lowered = {key.lower(): value for key, value in mapping.items()}
        return lowered.get(wanted, fallback)

    if isinstance(dicts, (list, tuple)):
        result = default
        for candidate in dicts:
            if isinstance(candidate, dict):
                result = _lookup(candidate, result)
        return result
    if isinstance(dicts, dict):
        return _lookup(dicts, default)
    return default
[docs]
def createOptionalString(paramaterdict, parameter, n=None):
    """Format an optional ASTRA parameter string.

    Looks ``parameter`` up with ``getParameter`` (default ``None``), converts
    the result with ``str``, and returns
    ``formatOptionalString(text, parameter, n)``. A missing parameter is
    therefore passed on as the string ``"None"``.

    NOTE(review): ``formatOptionalString`` is not defined in the visible part
    of this file; presumably it is defined or imported elsewhere and treats
    ``"None"`` as "not set" - confirm.
    """
    raw_value = getParameter(paramaterdict, parameter, default=None)
    text = str(raw_value)
    return formatOptionalString(text, parameter, n)
def _rotation_matrix(theta):
    """Return the 3x3 rotation by *theta* about the second (index 1) axis.

    The result is an ``np.ndarray`` laid out as
    ``[[cos, 0, sin], [0, 1, 0], [-sin, 0, cos]]``, which is the transpose of
    ``rotationMatrix(theta)`` in this module.
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    return np.array([[cos_t, 0, sin_t], [0, 1, 0], [-sin_t, 0, cos_t]])
[docs]
def isevaluable(self, s):
    """Return True if ``eval(s)`` completes without raising, else False.

    ``self`` is accepted for call compatibility and is not used. The
    expression really is executed, so any side effects of ``s`` happen.

    SECURITY NOTE(review): eval() runs arbitrary code - never pass text that
    may come from an untrusted source.
    """
    try:
        eval(s)
    except Exception:
        return False
    else:
        return True
[docs]
def path_function(a, b):
    """Return the absolute form of path *a*.

    *b* is accepted for call compatibility but ignored. An earlier,
    now-disabled variant returned *a* relative to *b* when both were on the
    same drive.
    """
    absolute_a = os.path.abspath(a)
    return absolute_a
[docs]
def expand_substitution(self, param, subs=None, elements=None, absolute=False):
    """Expand a ``$...$`` placeholder inside a string parameter.

    Non-string ``param`` values, and strings without a ``$...$`` span, are
    returned unchanged. Otherwise the text inside the greedy ``$...$`` span
    (from the first ``$`` to the last ``$`` on the same line) is handled as
    follows:

    - If ``isevaluable`` accepts it, it is evaluated and its ``str`` replaces
      the span. The resulting whole string is then evaluated too, falling
      back to the substituted text when that fails.
    - Otherwise the text is used verbatim (the ``$`` signs are dropped).

    Every key of ``subs`` is then replaced by its value. If the result is an
    existing filesystem path, it is passed through ``path_function`` and
    backslashes are turned into ``/``. A message is printed for every key of
    ``elements`` occurring in the result.

    Args:
        self: object whose ``global_parameters`` mapping supplies
            ``"master_lattice"`` and ``"master_subdir"``.
        param: value to expand.
        subs: optional mapping of literal text -> replacement text (values
            must be ``str``). A mapping passed in has its
            ``"master_lattice"`` and ``"master_subdir"`` entries overwritten
            in place.
        elements: optional mapping whose keys are reported when found.
        absolute: unused; kept for call compatibility.

    Returns:
        The expanded string, or ``param`` unchanged.
    """
    if not isinstance(param, str):
        return param
    # Fresh containers per call instead of shared mutable defaults: ``subs``
    # is written to below, which previously mutated the shared default dict.
    if subs is None:
        subs = {}
    if elements is None:
        elements = {}
    subs["master_lattice"] = (
        path_function(
            self.global_parameters["master_lattice"],
            self.global_parameters["master_subdir"],
        )
        + "/"
    )
    subs["master_subdir"] = "./"
    regex = re.compile(r"\$(.*)\$")
    s = re.search(regex, param)
    if not s:
        return param
    if isevaluable(self, s.group(1)):
        # SECURITY NOTE(review): eval() executes arbitrary code. If ``param``
        # can come from untrusted input files this needs a safe expression
        # evaluator instead.
        value = str(eval(s.group(1)))
        # A callable replacement inserts the text literally; a template
        # string would interpret backslashes (e.g. "\n", "\d", "\1").
        substituted = re.sub(regex, lambda _m: value, param)
        try:
            replaced_str = str(eval(substituted))
        except Exception:
            # The placeholder evaluated, but the surrounding text is not a
            # Python expression (e.g. "$2*3$mm"): keep the substituted text
            # instead of crashing.
            replaced_str = substituted
    else:
        inner = s.group(1)
        replaced_str = re.sub(regex, lambda _m: inner, param)
    for key in subs:
        replaced_str = replaced_str.replace(key, subs[key])
    if os.path.exists(replaced_str):
        replaced_str = path_function(
            replaced_str, self.global_parameters["master_subdir"]
        ).replace("\\", "/")
    for e in elements.keys():
        if e in replaced_str:
            print("Element is in string!", e, replaced_str)
    return replaced_str
[docs]
def checkValue(self, d, default=None):
    """Resolve a setting description to a concrete value.

    - ``d`` is a dict with ``d["type"] == "list"``:
      - if it has a ``"default"`` list, return ``d["value"]`` element-wise
        with ``None`` entries replaced by the matching default (pairs are
        zipped, so the result is as long as the shorter list);
      - otherwise, if ``d["value"]`` is a list, replace its ``None`` entries
        with ``default``; if it is not a list, return ``None``.
    - Any other dict: ``d["value"]`` is passed through
      ``expand_substitution`` and written back into ``d``. That value is
      returned unless it is ``None``, in which case ``d["default"]`` (if
      present) or ``default`` is returned.
    - ``d`` is a str: the attribute of ``self`` named ``d`` if it exists and
      is not ``None``, else ``default``.
    - Anything else: ``None``.
    """
    if isinstance(d, dict):
        if "type" in d and d["type"] == "list":
            if "default" in d:
                return [
                    given if given is not None else fallback
                    for given, fallback in zip(d["value"], d["default"])
                ]
            if not isinstance(d["value"], list):
                return None
            return [item if item is not None else default for item in d["value"]]
        d["value"] = expand_substitution(self, d["value"])
        if d["value"] is not None:
            return d["value"]
        return d["default"] if "default" in d else default
    if isinstance(d, str):
        if hasattr(self, d) and getattr(self, d) is not None:
            return getattr(self, d)
        return default
    return None
[docs]
def clean_directory(folder):
    """Delete every regular file directly inside *folder*.

    Subdirectories and other non-file entries are left in place. A failure
    to delete an individual entry is printed and the loop continues; errors
    from listing *folder* itself propagate.
    """
    for entry in os.listdir(folder):
        target = os.path.join(folder, entry)
        try:
            if not os.path.isfile(target):
                continue
            os.unlink(target)
        except Exception as err:
            print("clean_directory error:", err)
[docs]
def list_add(list1, list2):
    """Return the element-wise sums ``list1[k] + list2[k]`` as a list.

    The result stops at the shorter of the two inputs.
    """
    return list(map(lambda left, right: left + right, list1, list2))
[docs]
def symlink(source, link_name):
    """Create a symbolic link *link_name* pointing at *source*.

    Uses ``os.symlink`` when available; an existing *link_name* is silently
    left as it is (``FileExistsError`` is ignored). Without ``os.symlink``,
    the Win32 ``CreateSymbolicLinkW`` function is called through ``ctypes``,
    and a failure raises ``ctypes.WinError()``.
    """
    native_symlink = getattr(os, "symlink", None)
    if not callable(native_symlink):
        import ctypes

        create_link = ctypes.windll.kernel32.CreateSymbolicLinkW
        create_link.argtypes = (ctypes.c_wchar_p, ctypes.c_wchar_p, ctypes.c_uint32)
        create_link.restype = ctypes.c_ubyte
        # Flag 1 marks a directory link (SYMBOLIC_LINK_FLAG_DIRECTORY).
        link_flags = 1 if os.path.isdir(source) else 0
        if create_link(link_name, source, link_flags) == 0:
            raise ctypes.WinError()
        return
    try:
        native_symlink(source, link_name)
    except FileExistsError:
        pass
[docs]
def copylink(source, destination):
    """Copy the file *source* to *destination* (contents only, via ``copyfile``).

    Despite the name, no link is created. Any failure is printed rather than
    raised.
    """
    try:
        copyfile(source, destination)
    except Exception as err:
        print("copylink error!", err)
[docs]
def convert_numpy_types(v):
    """Recursively convert NumPy values into plain Python types.

    - dict: a new dict with every value converted (keys unchanged).
    - ndarray / list / tuple: a list of converted elements. An array that
      cannot be iterated (a 0-d array raises TypeError) becomes ``float(v)``.
    - NumPy bool / floating / integer scalars: ``bool`` / ``float`` / ``int``.
    - ``field`` instances: their ``model_dump()`` output, converted.
    - anything else: returned unchanged.
    """
    if isinstance(v, dict):
        return {key: convert_numpy_types(item) for key, item in v.items()}
    if isinstance(v, (np.ndarray, list, tuple)):
        try:
            return [convert_numpy_types(li) for li in v]
        except TypeError:
            # Presumably for 0-d arrays, whose iteration raises TypeError.
            return float(v)
    if isinstance(v, np.bool_):
        return bool(v)
    # The abstract scalar classes cover every width/platform alias (e.g.
    # np.longlong, np.ulonglong, np.longdouble) that an explicit list of
    # concrete types misses.
    if isinstance(v, np.floating):
        return float(v)
    if isinstance(v, np.integer):
        return int(v)
    if isinstance(v, field):
        return convert_numpy_types(v.model_dump())
    return v
[docs]
def pydantic_basemodel_dump_computed_fields(self, *args, **kwargs):
    """Dump only the computed fields of the pydantic model ``self``.

    ``self`` is serialised with pydantic's base ``BaseModel.model_dump``
    (``*args``/``**kwargs`` are forwarded). The output is then filtered to
    the names registered as computed fields on ``self``.

    Returns:
        Dict of ``{computed_field_name: value}``.
    """
    computed_keys = set(self.__pydantic_decorators__.computed_fields)
    # Serialise ``self`` itself. The previous ``BaseModel().model_dump(...)``
    # dumped a brand-new, field-less model, so the result was always {}.
    # Calling the unbound base implementation (rather than self.model_dump)
    # presumably avoids recursion if this function is installed as a
    # model_dump override - confirm against callers.
    full_dump = BaseModel.model_dump(self, *args, **kwargs)
    return {k: v for k, v in full_dump.items() if k in computed_keys}
[docs]
def normalize(obj):
    """Recursively convert *obj* into plain Python values suitable for diffing.

    - dicts and lists are rebuilt with normalised contents;
    - NumPy arrays and NumPy scalars are converted to Python objects
      (``tolist()`` / ``item()``) and those results are normalised too;
    - ``int`` and ``float`` values (including ``bool``) become ``float``;
    - anything else (``str``, ``None``, tuples, ...) is returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: normalize(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [normalize(v) for v in obj]
    if isinstance(obj, np.ndarray):
        # Normalise the converted elements as well, so np.array([1]) and the
        # plain list [1] both become [1.0].
        return normalize(obj.tolist())
    if isinstance(obj, np.generic):  # np.float64, np.int64, etc.
        # e.g. np.int64(3) -> 3 -> 3.0, matching the plain-int case below.
        return normalize(obj.item())
    if isinstance(obj, (int, float)):
        return float(obj)  # Normalize int to float
    return obj
[docs]
def deepdiff_to_nested(diff: dict) -> dict:
    """Convert the ``values_changed`` part of a DeepDiff result into nested dicts.

    Each DeepDiff path such as ``root['a']['b'][0]`` or ``root.a.b`` is split
    into keys (``['a', 'b', 0]``). The change is stored at that location as
    ``{"old": old_value, "new": new_value}``.

    Bracketed keys are parsed as Python literals, so ``[0]`` gives the int
    ``0`` while ``['0']`` stays the string ``'0'``. Bracket contents that are
    not literals are kept as raw text. A change with an empty path (the root
    itself) is stored as the top-level ``{"old": ..., "new": ...}``.

    Args:
        diff: DeepDiff result as a plain dict, e.g. ``DeepDiff(...).to_dict()``.

    Returns:
        Nested dict of changes; empty when there is no ``values_changed``.
    """
    import ast

    nested = {}
    changes = diff.get("values_changed")
    if not changes:
        return nested
    # Either a bracketed key ('quoted', "quoted" or bare) or a .attribute.
    token_re = re.compile(
        r"""\[('(?:[^'\\]|\\.)*'|"(?:[^"\\]|\\.)*"|[^\]]*)\]|\.([A-Za-z_]\w*)"""
    )
    for path, change in changes.items():
        # Strip only the leading "root"; the old str.replace removed every
        # occurrence, corrupting keys such as 'rootfinder'.
        rest = path[len("root"):] if path.startswith("root") else path
        keys = []
        for match in token_re.finditer(rest):
            bracket, attribute = match.groups()
            if attribute is not None:
                keys.append(attribute)
                continue
            try:
                keys.append(ast.literal_eval(bracket))
            except (ValueError, SyntaxError, TypeError):
                keys.append(bracket)
        entry = {"old": change["old_value"], "new": change["new_value"]}
        if not keys:
            nested.update(entry)
            continue
        d = nested
        for k in keys[:-1]:
            d = d.setdefault(k, {})
        d[keys[-1]] = entry
    return nested
[docs]
def compare_multiple_models(model_pairs: list[tuple[Element, Element]]) -> dict:
    """Diff each ``(old, new)`` model pair and collect the results by name.

    Both models of a pair are dumped with ``model_dump()`` and passed through
    ``normalize``. The dumps are compared with DeepDiff (order-insensitive,
    10 significant digits), and the ``values_changed`` section is reshaped by
    ``deepdiff_to_nested``.

    Args:
        model_pairs: sequence of ``(old_model, new_model)`` tuples.

    Returns:
        Dict mapping ``old_model.name`` to its nested change dict. A pair
        with no value changes maps to ``{}``; when several old models share a
        name, the last pair wins.
    """

    def _nested_diff(before, after):
        # Diff the normalised dumps of one model pair.
        left = normalize(before.model_dump())
        right = normalize(after.model_dump())
        difference = DeepDiff(left, right, ignore_order=True, significant_digits=10)
        return deepdiff_to_nested(difference.to_dict())

    changes = {}
    for old, new in model_pairs:
        changes[old.name] = _nested_diff(old, new)
    return changes
[docs]
def set_deep_attr(obj, dotted_path, value):
    """Set a nested attribute addressed by a dotted path like ``'a.b.c.d'``.

    Every segment but the last is resolved with ``getattr`` starting from
    *obj* (``AttributeError`` propagates for a missing one). The last segment
    is then assigned *value* with ``setattr``.
    """
    *parents, leaf = dotted_path.split(".")
    holder = obj
    for name in parents:
        holder = getattr(holder, name)
    setattr(holder, leaf, value)
[docs]
def flatten_changes_dict(d, parent_key=""):
    """Flatten a nested change dict into ``(dotted_path, value)`` pairs.

    A dict holding both ``"old"`` and ``"new"`` is a change record and
    contributes its ``"new"`` value. Any other dict is descended into, and
    every other value is emitted as-is.

    Args:
        d: nested dict, e.g. the output of ``deepdiff_to_nested``.
        parent_key: prefix for the keys of ``d``; ``""`` (or ``None``) means
            top level.

    Returns:
        List of ``(dotted_path, value)`` tuples in iteration order. Top-level
        keys are returned unchanged (not converted to ``str``); nested keys
        are joined to their parent with ``"."``.
    """
    items = []
    for key, value in d.items():
        # Test for "no parent" explicitly: a falsy parent such as the list
        # index 0 must still be used as a prefix.
        if parent_key is None or parent_key == "":
            path = key
        else:
            path = f"{parent_key}.{key}"
        if isinstance(value, dict) and not ("old" in value and "new" in value):
            items.extend(flatten_changes_dict(value, path))
        elif isinstance(value, dict):
            # This node is a change record; keep the new value.
            items.append((path, value["new"]))
        else:
            # Simple leaf
            items.append((path, value))
    return items