15 Stage Laser-Plasma Accelerator Surrogate
This example models an electron beam accelerated through fifteen stages of laser-plasma accelerators with ideal plasma lenses providing the focusing between stages. For more details, see:
Sandberg R T, Lehe R, Mitchell C E, Garten M, Myers A, Qiang J, Vay J-L and Huebl A. Synthesizing Particle-in-Cell Simulations Through Learning and GPU Computing for Hybrid Particle Accelerator Beamlines. Proc. of Platform for Advanced Scientific Computing (PASC’24), PASC24 Best Paper Award, 2024. DOI:10.1145/3659914.3659937
Sandberg R T, Lehe R, Mitchell C E, Garten M, Qiang J, Vay J-L and Huebl A. Hybrid Beamline Element ML-Training for Surrogates in the ImpactX Beam-Dynamics Code. 14th International Particle Accelerator Conference (IPAC’23), WEPA101, 2023. DOI:10.18429/JACoW-IPAC2023-WEPA101
A schematic with more information can be seen in the figure below:

Fig. 13 Schematic of the 15 stages of laser-plasma accelerators.
The laser-plasma accelerator elements are modeled with neural networks as surrogates.
These networks are trained beforehand.
In this example, pre-trained neural networks are downloaded from a Zenodo archive and saved in the models directory.
For more about how these neural network surrogate models were created,
see this description of a workflow for training neural networks from WarpX simulation data.
We use a 1 GeV electron beam with initial normalized rms emittance of 1 mm-mrad.
In this test, the initial and final values of \(\sigma_x\), \(\sigma_y\), \(\sigma_t\), \(\epsilon_x\), \(\epsilon_y\), and \(\epsilon_t\) must agree with nominal values.
Run
This example can only be run with Python:
Python script:
python3 run_ml_surrogate_15_stage.py
For MPI-parallel runs, prefix these lines with mpiexec -n 4 ... or srun -n 4 ..., depending on the system.
examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py.
#!/usr/bin/env python3
#
# Copyright 2022-2023 ImpactX contributors
# Authors: Ryan Sandberg, Axel Huebl, Chad Mitchell
# License: BSD-3-Clause-LBNL
#
# -*- coding: utf-8 -*-
import argparse
import amrex.space3d as amr
try:
import cupy as cp
cupy_available = True
except ImportError:
cupy_available = False
import numpy as np
import scipy.optimize as opt
from surrogate_model_definitions import surrogate_model
from impactx import (
Config,
CoordSystem,
ImpactX,
ImpactXParIter,
coordinate_transformation,
distribution,
elements,
)
try:
import torch
except ImportError:
print("Warning: Cannot import PyTorch. Skipping test.")
import sys
sys.exit(42) # ImpactX special return code for skipped tests
import zipfile
from urllib import request
# Command-line options of the example: beam size and number of LPA stages.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--num_particles",
    "-N",
    type=int,
    default=100000,
    help="number of particles to use in beam",
)
parser.add_argument(
    "--N_stages",
    "-ns",
    type=int,
    default=15,
    # only 1 through 15 stages are accepted
    choices=range(1, 16),
    help="number of LPA stages to simulate",
)
args = parser.parse_args()
# Pick the array library (CuPy or NumPy) and the torch device.
# The GPU path is used only when ImpactX was built with GPU support
# AND CuPy could be imported; otherwise everything stays on the CPU.
use_gpu = Config.have_gpu and cupy_available
xp = cp if use_gpu else np
array, stack, sqrt = xp.array, xp.stack, xp.sqrt
device = torch.device("cuda") if use_gpu else None
if use_gpu and Config.gpu_backend == "SYCL":
    print("Warning: SYCL GPU backend not yet implemented for Python")
print(f"device={device}" if use_gpu else "device set to default, cpu")
# Run configuration taken from the command line
N_stage = args.N_stages
tune_by_x_or_y = "x"  # plane whose alpha/beta are handed to UpdateConstF
npart = args.num_particles
# Stage geometry (presumably in meters -- TODO confirm units)
ebeam_lpa_z0 = -107e-6  # assigned to the reference particle's z further down
L_plasma = 0.28
L_transport = 0.03
L_stage_period = L_plasma + L_transport  # spacing between the starts of consecutive stages
drift_after_LPA = 43e-6
# length by which one surrogate push advances the reference particle
L_surrogate = abs(ebeam_lpa_z0) + L_plasma + drift_after_LPA
def download_and_unzip(url, data_dir):
    """Download a zip archive and unpack it.

    Parameters
    ----------
    url : str
        address of the zip archive to fetch
    data_dir : str
        path under which the downloaded archive file is stored
        (despite its name this is a file path, not a directory)

    The archive members are extracted into the current working directory.
    """
    request.urlretrieve(url, filename=data_dir)
    archive = zipfile.ZipFile(data_dir, mode="r")
    with archive:
        archive.extractall(path=None)
# Fetch the pre-trained stage models: the archive is saved as models.zip and
# unpacked into the current working directory.
print(
    "Downloading trained models from Zenodo.org - this might take a minute...",
    flush=True,
)
data_url = "https://zenodo.org/records/10810754/files/models.zip?download=1"
download_and_unzip(data_url, "models.zip")
# It was found that the PyTorch multithreaded defaults interfere with AMReX OpenMP
# when initializing the models or iterating elements:
# https://github.com/AMReX-Codes/pyamrex/issues/322
# https://github.com/BLAST-ImpactX/impactx/issues/773#issuecomment-2585043099
# So we manually set the number of threads to serial (1).
# Torch threading is not a problem with GPUs and might work when MPI is disabled.
# Could also just be a mixing of OpenMP libraries (gomp and llvm omp) when using the
# pre-build PyTorch pip packages.
torch.set_num_threads(1)
# One surrogate per simulated stage, read from models/beam_stage_<i>_model.pt
model_list = [
    surrogate_model(f"models/beam_stage_{stage_i}_model.pt", device=device)
    for stage_i in range(N_stage)
]
# Start with empty AMReX memory arenas (initial sizes set to 0).
pp_amrex = amr.ParmParse("amrex")
pp_amrex.add("the_arena_init_size", 0)
pp_amrex.add("the_device_arena_init_size", 0)
sim = ImpactX()
# set numerical parameters and IO control
sim.particle_shape = 2  # B-spline order
sim.space_charge = False
sim.diagnostics = True  # benchmarking
sim.slice_step_diagnostics = True
# domain decomposition & space charge mesh
sim.init_grids()
# load a 1 GeV electron beam with an initial
# normalized rms emittance of about 1 mm-mrad
# (lambdaX * lambdaPx * ref_u ~ 0.75e-6 * 1.33; unnormalized ~ 0.5 nm)
ref_u = 1957
energy_gamma = np.sqrt(1 + ref_u**2)
# NOTE(review): this is gamma * (electron rest energy), i.e. the total energy,
# but it is passed to set_kin_energy_MeV below -- confirm this is intended.
energy_MeV = 0.510998950 * energy_gamma  # reference energy
bunch_charge_C = 10.0e-15  # used with space charge
# reference particle
ref = sim.particle_container().ref_particle()
ref.set_charge_qe(-1.0).set_mass_MeV(0.510998950).set_kin_energy_MeV(energy_MeV)
ref.z = ebeam_lpa_z0
pc = sim.particle_container()
# uncorrelated Gaussian beam; transverse momentum spreads scale with 1/gamma
distr = distribution.Gaussian(
    lambdaX=0.75e-6,
    lambdaY=0.75e-6,
    lambdaT=0.1e-6,
    lambdaPx=1.33 / energy_gamma,
    lambdaPy=1.33 / energy_gamma,
    lambdaPt=1e-8,
    muxpx=0.0,
    muypy=0.0,
    mutpt=0.0,
)
sim.add_particles(bunch_charge_C, distr, npart)
n_slice = 1  # assigned to each surrogate stage's nslice when the lattice is built
class LPASurrogateStage(elements.Programmable):
    """Programmable lattice element that advances the beam through one
    laser-plasma accelerator stage by evaluating a surrogate model.
    """

    def __init__(self, stage_i, surrogate_model, surrogate_length, stage_start):
        """
        Parameters
        ----------
        stage_i:
            index of this stage (stored; not used further inside this class)
        surrogate_model:
            callable applied to the 6-component reference-particle tensor and
            to the (N, 6) particle tensor
        surrogate_length:
            length by which one push advances the reference particle's z, s and t
        stage_start:
            z offset of this stage; subtracted from z before the model is
            evaluated and added back to the particle z afterwards
        """
        elements.Programmable.__init__(self)
        self.stage_i = stage_i
        self.surrogate_model = surrogate_model
        self.surrogate_length = surrogate_length
        self.stage_start = stage_start
        # register the callback that performs the push
        self.push = self.surrogate_push
        self.ds = surrogate_length

    def surrogate_push(self, pc, step, period):
        """Push the reference particle and all beam particles of ``pc`` through
        the surrogate model; ``step`` and ``period`` are not used here.
        """
        ref_part = pc.ref_particle()
        ref_z_i = ref_part.z
        # reference z relative to the start of this stage
        ref_z_i_LPA = ref_z_i - self.stage_start
        ref_z_f = ref_z_i + self.surrogate_length
        ref_part_tensor = torch.tensor(
            [
                ref_part.x,
                ref_part.y,
                ref_z_i_LPA,
                ref_part.px,
                ref_part.py,
                ref_part.pz,
            ],
            device=device,
            dtype=torch.float64,
        )
        # magnitude of the reference momentum vector (components 3..5);
        # computed before the model call below
        ref_beta_gamma = torch.sqrt(torch.sum(ref_part_tensor[3:] ** 2))
        ref_beta_gamma = ref_beta_gamma.to(device)
        with torch.no_grad():
            ref_part_model_final = self.surrogate_model(ref_part_tensor)
        # only the longitudinal momentum of the model output is kept for the
        # reference particle; its transverse position/momentum are set to zero
        ref_uz_f = ref_part_model_final[5]
        ref_beta_gamma_final = ref_uz_f
        ref_part_final = torch.tensor(
            [0, 0, ref_z_f, 0, 0, ref_uz_f], device=device, dtype=torch.float64
        )
        # the model is evaluated in fixed-t coordinates
        coordinate_transformation(pc, CoordSystem.t)
        for lvl in range(pc.finest_level + 1):
            for pti in ImpactXParIter(pc, level=lvl):
                soa = pti.soa()
                real_arrays = soa.get_real_data()
                # copy=False: these wrap the particle data so that the
                # slice assignments further down write back into it
                x = array(real_arrays[0], copy=False)
                y = array(real_arrays[1], copy=False)
                t = array(real_arrays[2], copy=False)
                px = array(real_arrays[3], copy=False)
                py = array(real_arrays[4], copy=False)
                pt = array(real_arrays[5], copy=False)
                data_arr = torch.tensor(
                    stack([x, y, t, px, py, pt], axis=1),
                    device=device,
                    dtype=torch.float64,
                )
                # from coordinates relative to the reference particle to
                # stage-local positions and unnormalized momenta
                data_arr[:, 0] += ref_part.x
                data_arr[:, 1] += ref_part.y
                data_arr[:, 2] += ref_z_i_LPA
                data_arr[:, 3:] *= ref_beta_gamma
                data_arr[:, 3] += ref_part.px
                data_arr[:, 4] += ref_part.py
                data_arr[:, 5] += ref_part.pz
                with torch.no_grad():
                    data_arr_post_model = self.surrogate_model(data_arr)
                # z += stage start
                data_arr_post_model[:, 2] += self.stage_start
                # back to ref particle coordinates
                for ii in range(3):
                    data_arr_post_model[:, ii] -= ref_part_final[ii]
                    data_arr_post_model[:, 3 + ii] -= ref_part_final[3 + ii]
                    data_arr_post_model[:, 3 + ii] /= ref_beta_gamma_final
                # write the pushed coordinates back into the particle arrays
                x[:] = array(data_arr_post_model[:, 0])
                y[:] = array(data_arr_post_model[:, 1])
                t[:] = array(data_arr_post_model[:, 2])
                px[:] = array(data_arr_post_model[:, 3])
                py[:] = array(data_arr_post_model[:, 4])
                pt[:] = array(data_arr_post_model[:, 5])
        # TODO this part needs to be corrected for general geometry
        # where the initial vector might not point in z
        # and even if it does, bending elements may change the direction
        ref_part.x = ref_part_final[0]
        ref_part.y = ref_part_final[1]
        ref_part.z = ref_part_final[2]
        ref_gamma = torch.sqrt(1 + ref_beta_gamma_final**2)
        ref_part.px = ref_part_final[3]
        ref_part.py = ref_part_final[4]
        ref_part.pz = ref_part_final[5]
        ref_part.pt = -ref_gamma
        ref_part.s += self.surrogate_length
        # NOTE(review): t is advanced by the same length as s -- presumably
        # assumes the reference particle moves at ~c; confirm
        ref_part.t += self.surrogate_length
        # return to fixed-s coordinates for the following lattice elements
        coordinate_transformation(pc, CoordSystem.s)
        ## Done!
# Transport section between two stages: drift - lens - drift.
L_transport = 0.03  # same value as assigned earlier in this script
L_lens = 0.003
L_focal = 0.5 * L_transport
L_drift = 0.5 * (L_transport - L_lens)
# initial lens strength; UpdateConstF overwrites kx/ky of each lens during tracking
K = np.sqrt(2.0 / L_focal / L_lens)
Kt = 1e-11  # number chosen arbitrarily since 0 isn't allowed
dL = 0
L_drift_minus_surrogate = L_drift  # not referenced elsewhere in the visible script
# the surrogate push already covers drift_after_LPA behind the plasma ...
L_drift_1 = L_drift - drift_after_LPA - dL
# ... and |ebeam_lpa_z0| in front of the next one, so both drifts are shortened
L_drift_before_2nd_stage = abs(ebeam_lpa_z0)
L_drift_2 = L_drift - L_drift_before_2nd_stage + dL
def get_lattice_element_iter(sim, j):
    """Return an iterator over ``sim.lattice`` positioned just behind element ``j``.

    Elements ``0`` through ``j`` have already been consumed, so the first item
    the returned iterator yields is element ``j + 1`` (if there is one).

    Parameters
    ----------
    sim : object with a sized, iterable ``lattice`` attribute
    j : int
        index of the last element to skip, ``0 <= j < len(sim.lattice)``
    """
    n_elements = len(sim.lattice)
    assert 0 <= j < n_elements, (
        f"Argument j must be a nonnegative integer satisfying 0 <= j < {n_elements}, not {j}"
    )
    lattice_iter = iter(sim.lattice)
    for _ in range(j + 1):
        next(lattice_iter)
    return lattice_iter
def lens_eqn(k, lens_length, alpha, beta, gamma):
    """Residual whose root in ``k`` is the lens strength searched for by UpdateConstF.

    Evaluates ``tan(k * lens_length) + 2 * alpha / (k * beta - gamma / k)`` for
    the given Twiss parameters ``alpha``, ``beta`` and ``gamma``.
    """
    phase_term = np.tan(k * lens_length)
    denominator = k * beta - gamma / k
    return phase_term + 2 * alpha / denominator
# lens strengths found during tracking, in the order the lenses were tuned
k_list = []


class UpdateConstF(elements.Programmable):
    """Programmable element that retunes the lattice element directly behind it.

    When pushed, it reads the current Twiss parameters of the beam, solves
    ``lens_eqn`` for the focusing strength and stores the result in ``kx``
    and ``ky`` of the element at ``lattice_index + 1``.
    """

    def __init__(self, sim, stage_i, lattice_index, x_or_y):
        """
        Parameters
        ----------
        sim : simulation object whose ``lattice`` contains this element
        stage_i : index of the stage this element follows (stored only)
        lattice_index : position of this element inside ``sim.lattice``
        x_or_y : "x" or "y", the plane whose alpha/beta are used
        """
        elements.Programmable.__init__(self)
        self.sim, self.stage_i = sim, stage_i
        self.lattice_index, self.x_or_y = lattice_index, x_or_y
        self.push = self.set_lens

    def set_lens(self, pc, step, period):
        """Solve for the lens strength and apply it to the following element."""
        plane = self.x_or_y
        # current envelope parameters of the beam
        moments = pc.reduced_beam_characteristics()
        alpha, beta = moments[f"alpha_{plane}"], moments[f"beta_{plane}"]
        gamma = (1 + alpha**2) / beta
        # root of lens_eqn inside the fixed search bracket
        k_new = opt.root_scalar(
            lens_eqn, args=(L_lens, alpha, beta, gamma), bracket=[100, 300]
        ).root
        # the element right behind this one in the lattice is the lens to tune
        following_lens = next(get_lattice_element_iter(self.sim, self.lattice_index))
        k_list.append(k_new)
        following_lens.kx = following_lens.ky = k_new
# One surrogate element per stage; stage i starts at L_stage_period * i.
lpa_stages = []
for i in range(N_stage):
    lpa = LPASurrogateStage(i, model_list[i], L_surrogate, L_stage_period * i)
    lpa.nslice = n_slice
    lpa.ds = L_surrogate
    lpa.threadsafe = False
    lpa_stages.append(lpa)
monitor = elements.BeamMonitor("monitor", backend="h5")
# Lattice layout: every stage except the last contributes 9 elements
#   0 monitor, 1 LPA, 2 monitor, 3 drift, 4 monitor,
#   5 UpdateConstF, 6 ConstF lens, 7 monitor, 8 drift
# so the UpdateConstF of stage i sits at index 5 + 9 * i and the lens it
# tunes directly behind it. Keep lattice_index in sync when changing this.
for i in range(N_stage):
    sim.lattice.extend(
        [
            monitor,
            lpa_stages[i],
        ]
    )
    # no transport section behind the last stage
    if i != N_stage - 1:
        sim.lattice.extend(
            [
                monitor,
                elements.Drift(ds=L_drift_1),
                monitor,
                UpdateConstF(
                    sim=sim, stage_i=i, lattice_index=5 + 9 * i, x_or_y=tune_by_x_or_y
                ),
                elements.ConstF(ds=L_lens, kx=K, ky=K, kt=Kt),
                monitor,
                elements.Drift(ds=L_drift_2),
            ]
        )
sim.lattice.extend([monitor])
sim.track_particles()
sim.finalize()
This script requires some utility code for using the neural networks that is provided here:
Script surrogate_model_definitions.py
examples/pytorch_surrogate_model/surrogate_model_definitions.py.
#!/usr/bin/env python3
#
# Copyright 2022-2023 ImpactX contributors
# Authors: Ryan Sandberg, Axel Huebl
# License: BSD-3-Clause-LBNL
#
# -*- coding: utf-8 -*-
from enum import Enum
# PyTorch is optional: without it the example cannot run and is skipped.
try:
    import torch
    from torch import nn
except ImportError:
    print("Warning: Cannot import PyTorch. Skipping test.")
    import sys

    # Exit with the ImpactX return code for skipped tests (42), the same code
    # run_ml_surrogate_15_stage.py uses. Exiting with 0 here would report a
    # run that never happened as a success, because the run script imports
    # this module before its own PyTorch check.
    sys.exit(42)
class Activation(Enum):
    """
    Enumeration of the activation layers supported by OneActNN.

    get_enum_type accepts a member, its integer value or its name.
    """

    ReLU = 1
    Tanh = 2
    PReLU = 3
    Sigmoid = 4
def get_enum_type(type_to_test, EnumClass):
    """
    Returns the enumeration member of EnumClass that type_to_test refers to

    Parameters
    ----------
    type_to_test: EnumClass, int or str
        a member of EnumClass (returned unchanged), the integer value of a
        member, or the name of a member
    EnumClass: Enum class
        Enum class to look the member up in

    Returns
    -------
    the matching member of EnumClass

    Raises
    ------
    ValueError
        if an integer is given that is not the value of any member
    KeyError
        if a string is given that is not the name of any member
    TypeError
        if type_to_test is neither a member of EnumClass, an int nor a str
    """
    if isinstance(type_to_test, EnumClass):
        return type_to_test
    if isinstance(type_to_test, int):
        return EnumClass(type_to_test)
    if isinstance(type_to_test, str):
        # name lookup restricted to members; getattr would also hand back
        # unrelated class attributes such as methods
        return EnumClass[type_to_test]
    # previously fell through and returned None, which callers then treated
    # as "no activation" without any warning
    raise TypeError(
        f"Cannot convert {type_to_test!r} of type {type(type_to_test).__name__} "
        f"to {EnumClass.__name__}: expected a member, an int or a str"
    )
class ConnectedNN(nn.Module):
    """
    ConnectedNN is a class of fully connected neural networks

    The given layers are chained in order inside an ``nn.Sequential`` that is
    stored as ``self.stack``; state-dict keys therefore start with ``stack.``.
    """

    def __init__(self, layers, device=None):
        """
        Parameters
        ----------
        layers: iterable of nn.Module
            layers applied one after the other
        device: optional
            if given, the network is moved to this device
        """
        super().__init__()
        self.stack = nn.Sequential(*layers)
        if device is not None:
            self.to(device)

    def forward(self, x):
        """Apply all layers of the stack to ``x`` and return the result."""
        return self.stack(x)
class OneActNN(ConnectedNN):
    """
    Fully connected network that uses the same activation after every hidden
    linear layer.

    Layer order: Linear(n_in, n_hidden_nodes), then for each hidden layer one
    activation followed (except after the last one) by
    Linear(n_hidden_nodes, n_hidden_nodes), and finally
    Linear(n_hidden_nodes, n_out).
    """

    def __init__(self, n_in, n_out, n_hidden_nodes, n_hidden_layers, act, device=None):
        """
        Parameters
        ----------
        n_in, n_out: int
            number of input and output features
        n_hidden_nodes: int
            width of every hidden layer
        n_hidden_layers: int
            number of activation layers
        act: Activation
            activation to use; any other value results in no activation layers
        device: optional
            passed on to ConnectedNN
        """
        self.n_in = n_in
        self.n_out = n_out
        self.n_hidden_layers = n_hidden_layers
        self.n_hidden_nodes = n_hidden_nodes
        self.act = act

        # identity comparison, exactly one (or no) constructor matches
        constructors = (
            (Activation.ReLU, nn.ReLU),
            (Activation.Tanh, nn.Tanh),
            (Activation.PReLU, nn.PReLU),
            (Activation.Sigmoid, nn.Sigmoid),
        )
        make_activation = next(
            (ctor for member, ctor in constructors if act is member), None
        )

        layers = [nn.Linear(n_in, n_hidden_nodes)]
        last_hidden = n_hidden_layers - 1
        for idx in range(n_hidden_layers):
            if make_activation is not None:
                # a fresh module per layer (PReLU carries its own parameter)
                layers.append(make_activation())
            if idx < last_hidden:
                layers.append(nn.Linear(n_hidden_nodes, n_hidden_nodes))
        layers.append(nn.Linear(n_hidden_nodes, n_out))
        super().__init__(layers, device)
class surrogate_model:
    """
    Extend the functionality of the OneActNN class

    This class is meant to act as a wrapper for the OneActNN class.
    It provides a `__call__` function that normalizes input and returns dimensional output.
    """

    def __init__(self, model_file, device=None):
        """
        Parameters
        ----------
        model_file: str or path
            file written with ``torch.save`` holding the keys ``source_means``,
            ``source_stds``, ``target_means``, ``target_stds``, ``activation``,
            ``model_state_dict`` and optionally ``n_hidden_layers``
        device: optional
            torch device for the network and the normalization tensors;
            ``None`` loads the file onto the CPU
        """
        self.device = device
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # Python objects. Only load model files from a trusted source.
        model_dict = torch.load(
            model_file,
            map_location="cpu" if device is None else device,
            weights_only=False,
        )
        self.source_means = torch.tensor(
            model_dict["source_means"], device=self.device, dtype=torch.float64
        )
        self.target_means = torch.tensor(
            model_dict["target_means"], device=self.device, dtype=torch.float64
        )
        self.source_stds = torch.tensor(
            model_dict["source_stds"], device=self.device, dtype=torch.float64
        )
        self.target_stds = torch.tensor(
            model_dict["target_stds"], device=self.device, dtype=torch.float64
        )

        # the network shape is recovered from the stored weights
        state_dict = model_dict["model_state_dict"]
        first_weight = state_dict["stack.0.weight"]
        n_in = first_weight.shape[1]
        n_hidden_nodes = first_weight.shape[0]
        final_layer_key = list(state_dict.keys())[-1]
        n_out = state_dict[final_layer_key].shape[0]

        activation = get_enum_type(model_dict["activation"], Activation)
        if "n_hidden_layers" in model_dict:
            n_hidden_layers = model_dict["n_hidden_layers"]
        elif activation is Activation.PReLU:
            # every PReLU layer adds one entry to the state dict besides the
            # weight and bias of each linear layer
            n_hidden_layers = int((len(state_dict) - 2) / 3)
        else:
            # only weight and bias entries of the linear layers are stored
            n_hidden_layers = int(len(state_dict) / 2 - 1)

        self.neural_network = OneActNN(
            n_in=n_in,
            n_out=n_out,
            n_hidden_nodes=n_hidden_nodes,
            n_hidden_layers=n_hidden_layers,
            act=activation,
            device=device,
        )
        self.neural_network.load_state_dict(state_dict)
        self.neural_network.eval()

    def __call__(self, data_arr):
        """Evaluate the surrogate on dimensional input.

        The input is normalized with the source mean/std, passed through the
        network in single precision, and the output is scaled back with the
        target std/mean.

        Parameters
        ----------
        data_arr: torch.Tensor
            input whose last dimension matches the source statistics;
            it is left unmodified

        Returns
        -------
        torch.Tensor of dtype float64 holding the dimensional model output
        """
        # out-of-place on purpose: the caller's tensor must not be overwritten
        # with its normalized values
        normalized = (data_arr - self.source_means) / self.source_stds
        with torch.no_grad():
            data_arr_post_model = self.neural_network(normalized.float()).double()
        return data_arr_post_model * self.target_stds + self.target_means
Analyze
We run the following script to analyze correctness:
Script analyze_ml_surrogate_15_stage.py
examples/pytorch_surrogate_model/analyze_ml_surrogate_15_stage.py.
#!/usr/bin/env python3
#
# Copyright 2022-2023 ImpactX contributors
# Authors: Ryan Sandberg, Axel Huebl, Chad Mitchell
# License: BSD-3-Clause-LBNL
#
# -*- coding: utf-8 -*-
import numpy as np
import openpmd_api as io
from scipy.stats import moment
def get_moments(beam):
    """Compute rms beam sizes and rms emittances of a particle data frame.

    Parameters
    ----------
    beam : pandas.DataFrame
        with columns ``position_{x,y,t}`` and ``momentum_{x,y,t}``

    Returns
    -------
    sigx, sigy, sigt, emittance_x, emittance_y, emittance_t
    """
    covariance = beam.cov(ddof=0)

    def plane_moments(axis):
        # rms size and emittance sqrt(<q^2><p^2> - <qp>^2) of one plane
        pos, mom = f"position_{axis}", f"momentum_{axis}"
        sig_pos = moment(beam[pos], 2) ** 0.5  # variance -> std dev.
        sig_mom = moment(beam[mom], 2) ** 0.5
        emittance = (sig_pos**2 * sig_mom**2 - covariance[pos][mom] ** 2) ** 0.5
        return sig_pos, emittance

    (sigx, emittance_x), (sigy, emittance_y), (sigt, emittance_t) = (
        plane_moments(axis) for axis in "xyt"
    )
    return (sigx, sigy, sigt, emittance_x, emittance_y, emittance_t)
# initial/final beam
series = io.Series("diags/openPMD/monitor.h5", io.Access.read_only)
last_step = list(series.iterations)[-1]
initial = series.iterations[1].particles["beam"].to_df()
final = series.iterations[last_step].particles["beam"].to_df()
# compare number of particles
num_particles = 100000
assert num_particles == len(initial)
assert num_particles == len(final)
print("Initial Beam:")
sigx, sigy, sigt, emittance_x, emittance_y, emittance_t = get_moments(initial)
print(f" sigx={sigx:e} sigy={sigy:e} sigt={sigt:e}")
print(
f" emittance_x={emittance_x:e} emittance_y={emittance_y:e} emittance_t={emittance_t:e}"
)
atol = 0.0 # ignored
rtol = num_particles**-0.5 # from random sampling of a smooth distribution
print(f" rtol={rtol} (ignored: atol~={atol})")
assert np.allclose(
[sigx, sigy, sigt, emittance_x, emittance_y],
[
7.494325e-07,
7.478916e-07,
9.976192e-08,
5.070297e-10,
5.080007e-10,
],
rtol=rtol,
atol=atol,
)
atol = 1.0e-6
print(f" atol~={atol}")
assert np.allclose([emittance_t], [0.0], atol=atol)
print("")
print("Final Beam:")
sigx, sigy, sigt, emittance_x, emittance_y, emittance_t = get_moments(final)
print(f" sigx={sigx:e} sigy={sigy:e} sigt={sigt:e}")
print(
f" emittance_x={emittance_x:e} emittance_y={emittance_y:e} emittance_t={emittance_t:e}"
)
atol = 0.0 # ignored
rtol = num_particles**-0.5 # from random sampling of a smooth distribution
print(f" rtol={rtol} (ignored: atol~={atol})")
assert np.allclose(
[sigx, sigy, sigt, emittance_x, emittance_y, emittance_t],
[
1.590999e-07,
1.634865e-07,
1.030930e-07,
5.031797e-12,
5.242205e-12,
2.049623e-11,
],
rtol=rtol,
atol=atol,
)
Visualize
You can run the following script to visualize the beam evolution over time:
Script visualize_ml_surrogate_15_stage.py
examples/pytorch_surrogate_model/visualize_ml_surrogate_15_stage.py.
#!/usr/bin/env python3
#
# Copyright 2022-2023 ImpactX contributors
# Authors: Ryan Sandberg, Axel Huebl, Chad Mitchell
# License: BSD-3-Clause-LBNL
#
# -*- coding: utf-8 -*-
import argparse
import glob
import re
import numpy as np
import openpmd_api as io
import pandas as pd
from matplotlib import pyplot as plt
from scipy.constants import c, e, m_e
def read_all_files(file_pattern):
    """Load every whitespace-separated text file matching ``file_pattern``
    (e.g. one per MPI rank and potentially OpenMP thread) and stack the rows
    into a single table indexed by the ``id`` column.

    Returns
    -------
    pandas.DataFrame
    """
    frames = [
        pd.read_csv(path, delimiter=r"\s+") for path in glob.glob(file_pattern)
    ]
    combined = pd.concat(frames, axis=0, ignore_index=True)
    return combined.set_index("id")
def read_file(file_pattern):
    """Yield one data frame per file matching ``file_pattern``.

    Each file is read as whitespace-separated text. If it has no ``step``
    column, one is added holding the first integer found in the file's name.

    Parameters
    ----------
    file_pattern : str
        glob pattern of the files to read

    Yields
    ------
    pandas.DataFrame

    Raises
    ------
    IndexError
        if a file without ``step`` column has no digits in its name
    """
    import os

    for filename in glob.glob(file_pattern):
        df = pd.read_csv(filename, delimiter=r"\s+")
        if "step" not in df.columns:
            # look at the file name only: digits in a parent directory
            # (e.g. "run3/diags/...") must not be taken for the step
            step = int(re.findall(r"[0-9]+", os.path.basename(filename))[0])
            df["step"] = step
        yield df
def read_time_series(file_pattern):
    """Read all files matching ``file_pattern`` with ``read_file`` (which
    makes sure each table has a ``step`` column) and stack them row-wise
    into one table with a fresh running index.

    Returns
    -------
    pandas.DataFrame
    """
    per_file_frames = read_file(file_pattern)
    stacked = pd.concat(per_file_frames, axis=0, ignore_index=True)
    return stacked  # .set_index('id')
from enum import Enum


class TCoords(Enum):
    """Target frame of ``to_t``: relative to the reference particle or global."""

    REF = 1
    GLOBAL = 2


def to_t(
    ref_pz, ref_pt, data_arr_s, ref_z=None, coord_type=TCoords.REF
):  # x, y, t, dpx, dpy, dpt):
    """Change to fixed t coordinates

    The particle data is converted in place; nothing is returned.

    Parameters
    ---
    ref_pz: float, reference particle momentum in z
    ref_pt: float, reference particle pt = -gamma
    data_arr_s: pandas DataFrame with the columns position_{x,y,t} and
        momentum_{x,y,t}, or Nx6 numpy array with the columns in that order,
        containing fixed-s particle coordinates
    ref_z: reference particle z, required if transforming to global coordinates
    coord_type: TCoords enum, (default is in ref coordinates) whether to get particle data relative to reference coordinate or in the global frame

    Raises
    ---
    TypeError: if data_arr_s is neither a DataFrame nor a numpy array
    AssertionError: if an array does not have 6 columns, or ref_z is missing
        for global coordinates (raised before any data is modified)
    """
    columns = (
        "position_x",
        "position_y",
        "position_t",
        "momentum_x",
        "momentum_y",
        "momentum_t",
    )
    is_frame = isinstance(data_arr_s, pd.DataFrame)
    if is_frame:
        dx, dy, dt, dpx, dpy, dpt = (data_arr_s[name] for name in columns)
    elif isinstance(data_arr_s, np.ndarray):
        assert data_arr_s.shape[1] == 6, (
            f"data_arr_s.shape={data_arr_s.shape} but data_arr_s must be an Nx6 array"
        )
        dx, dy, dt, dpx, dpy, dpt = data_arr_s.T
    else:
        raise TypeError(
            f"Incompatible input type {type(data_arr_s)} for data_arr_s, must be pandas DataFrame or Nx6 array-like object"
        )
    if coord_type is TCoords.GLOBAL:
        # checked up front so that a missing ref_z cannot leave the data
        # half-converted
        assert ref_z is not None, (
            "Reference particle z coordinate is required to transform to global coordinates"
        )

    # All new values are computed out of place from the untouched inputs and
    # written back explicitly at the end. Modifying the column Series obtained
    # from a DataFrame in place does not reach the DataFrame when pandas
    # copy-on-write is active.
    pt_total = ref_pt + ref_pz * dpt
    x_new = dx + ref_pz * dpx * dt / pt_total
    y_new = dy + ref_pz * dpy * dt / pt_total
    pz = np.sqrt(
        -1 + pt_total**2 - (ref_pz * dpx) ** 2 - (ref_pz * dpy) ** 2
    )
    t_new = dt * (pz / pt_total)
    px_new, py_new = dpx, dpy
    pt_new = pz - ref_pz
    if coord_type is TCoords.REF:
        print("applying reference normalization")
        pt_new = pt_new / ref_pz
    elif coord_type is TCoords.GLOBAL:
        print("target global coordinates")
        t_new = t_new + ref_z
        px_new = dpx * ref_pz
        py_new = dpy * ref_pz
        pt_new = pt_new + ref_pz

    new_values = (x_new, y_new, t_new, px_new, py_new, pt_new)
    if is_frame:
        for name, values in zip(columns, new_values):
            data_arr_s[name] = values
    else:
        for col, values in enumerate(new_values):
            data_arr_s[:, col] = values
    # data_arr_t = np.column_stack([xt,yt,z,dpx,dpy,dpz])
    return  # modifies data_arr_s in place
def plot_beam_df(
    beam_at_step,
    axT,
    unit=1e6,
    unit_z=1e3,
    unit_label=r"$\mu$m",
    unit_z_label="mm",
    alpha=1.0,
    cmap=None,
    color="k",
    size=0.1,
    t_offset=0.0,
    label=None,
    z_ticks=None,
):
    """Scatter-plot nine projections of a beam onto a 3x3 grid of axes.

    Top row: x-y, t-x, t-y. Middle row: px-py, pt-px, pt-py.
    Bottom row: x-px, y-py, t-pt.

    Parameters
    ----------
    beam_at_step : pandas.DataFrame
        with columns ``position_{x,y,t}`` and ``momentum_{x,y,t}``
    axT : 3x3 indexable of matplotlib axes
    unit, unit_label : factor applied to x and y positions and its label text
    unit_z, unit_z_label : factor applied to position_t and its axis label
    alpha, cmap, color, size : passed to every ``scatter`` call
    t_offset : subtracted from the scaled position_t
    label : legend label, attached to the pt-py panel only
    z_ticks : optional x tick positions for the panels with t on the x axis
    """
    pos_x = beam_at_step.position_x.multiply(unit)
    pos_y = beam_at_step.position_y.multiply(unit)
    pos_t = beam_at_step.position_t.multiply(unit_z) - t_offset
    mom_x = beam_at_step.momentum_x
    mom_y = beam_at_step.momentum_y
    mom_t = beam_at_step.momentum_t

    x_text = r"x (%s)" % unit_label
    y_text = r"y (%s)" % unit_label
    t_text = r"%s" % unit_z_label

    # (row, column, x data, y data, x label, y label, t on the x axis?)
    panels = (
        (0, 0, pos_x, pos_y, x_text, y_text, False),
        (0, 1, pos_t, pos_x, t_text, x_text, True),
        (0, 2, pos_t, pos_y, t_text, y_text, True),
        (1, 0, mom_x, mom_y, "px", "py", False),
        (1, 1, mom_t, mom_x, "pt", "px", False),
        (1, 2, mom_t, mom_y, "pt", "py", False),
        (2, 0, pos_x, mom_x, x_text, "px", False),
        (2, 1, pos_y, mom_y, y_text, "py", False),
        (2, 2, pos_t, mom_t, t_text, "pt", True),
    )
    legend_panel = (1, 2)

    for row, col, x_data, y_data, x_label, y_label, t_on_x in panels:
        ax = axT[row][col]
        # only the legend panel receives the label keyword
        extra = {"label": label} if (row, col) == legend_panel else {}
        ax.scatter(x_data, y_data, c=color, s=size, alpha=alpha, cmap=cmap, **extra)
        if extra and label is not None:
            ax.legend()
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        # panels with t on the x axis keep plain tick labels there
        ax.axes.ticklabel_format(
            axis="y" if t_on_x else "both", style="sci", scilimits=(-2, 2)
        )
        if t_on_x and z_ticks is not None:
            ax.set_xticks(z_ticks)
    plt.tight_layout()
    # done
# options to run this script
parser = argparse.ArgumentParser(description="Plot the ML surrogate benchmark.")
parser.add_argument(
    "--save-png", action="store_true", help="non-interactive run: save to PNGs"
)
parser.add_argument(
    "--num-stages", "-n", type=int, default=15, help="num stages to plot"
)
# The values index the list of stage-end steps built further down:
# 0 selects the initial beam, k >= 1 the beam behind the k-th stage.
# (The help text used to be a copy of the --num-stages one.)
parser.add_argument(
    "--stages_to_plot",
    "-s",
    type=int,
    help="stages whose phase spaces are plotted (1 = after the first stage)",
    nargs="*",
)
args = parser.parse_args()
# Reference-particle and reduced beam diagnostics, one row per recorded step.
impactx_surrogate_ref_particle = read_time_series("diags/ref_particle.*")
impactx_surrogate_reduced_diags = read_time_series(
    "diags/reduced_beam_characteristics.*"
)
ref_gamma = impactx_surrogate_ref_particle["gamma"]
# mean beam gamma: reference gamma corrected by the mean pt deviation,
# which is scaled with the reference beta*gamma
beam_gamma = (
    ref_gamma
    - impactx_surrogate_reduced_diags["pt_mean"]
    * impactx_surrogate_ref_particle["beta_gamma"]
)
beam_u = np.sqrt(beam_gamma**2 - 1)
# normalized emittances = emittance * beta*gamma of the beam
emit_x = impactx_surrogate_reduced_diags["emittance_x"]
emit_nx = emit_x * beam_u
emit_y = impactx_surrogate_reduced_diags["emittance_y"]
emit_ny = emit_y * beam_u
# rows to plot: the first one plus one per stage, 9 rows apart
# (presumably the row right behind each LPA element -- verify against the lattice)
ix_slice = [0] + [2 + 9 * i for i in range(args.num_stages)]
############# plot moments ##############
fig, axT = plt.subplots(2, 2, figsize=(10, 8))
ymarker = "^"
s_at_slices = impactx_surrogate_ref_particle["s"][ix_slice]


def _plot_xy_pair(ax, x_values, y_values, scale, ylabel, method="plot"):
    """Draw the x (blue dots) and y (red markers) series of one quantity
    against s on ``ax`` and label the panel."""
    draw = getattr(ax, method)
    draw(s_at_slices, x_values[ix_slice] * scale, "bo", label="x")
    draw(
        s_at_slices,
        y_values[ix_slice] * scale,
        "r",
        marker=ymarker,
        linestyle="None",
        label="y",
    )
    ax.legend()
    ax.set_xlabel("s (m)")
    ax.set_ylabel(ylabel)


######### emittance ##########
_plot_xy_pair(axT[0][0], emit_nx, emit_ny, 1e6, r"emittance (mm-mrad)")
######### energy ##########
ax = axT[0][1]
scale = m_e * c**2 / e * 1e-9  # electron rest energy in GeV
ax.plot(s_at_slices, beam_gamma[ix_slice] * scale, "go")
ax.set_xlabel("s (m)")
ax.set_ylabel(r"mean energy (GeV)")
######### width ##########
_plot_xy_pair(
    axT[1][0],
    impactx_surrogate_reduced_diags["sig_x"],
    impactx_surrogate_reduced_diags["sig_y"],
    1e6,
    r"beam width ($\mu$m)",
)
######### divergence ##########
_plot_xy_pair(
    axT[1][1],
    impactx_surrogate_reduced_diags["sig_px"],
    impactx_surrogate_reduced_diags["sig_py"],
    1e3,
    r"divergence (mrad)",
    method="semilogy",
)
plt.tight_layout()
if args.save_png:
    plt.savefig("lpa_ml_surrogate_moments.png")
else:
    plt.show()
######## plot phase spaces ###########
beam_impactx_surrogate_series = io.Series(
    "diags/openPMD/monitor.h5", io.Access.read_only
)
impactx_surrogate_steps = list(beam_impactx_surrogate_series.iterations)
millimeter = 1.0e3
micron = 1.0e6
N_stage = args.num_stages
# monitor step of the initial beam followed by the step behind each stage
impactx_stage_end_steps = [1] + [3 + 9 * i for i in range(N_stage)]
ise = impactx_stage_end_steps


def _plot_phase_spaces_at(step, title_prefix, png_name):
    """Plot the 3x3 phase-space projections of the beam recorded at ``step``.

    The beam is converted to fixed-t global coordinates first. The figure is
    saved as ``png_name`` with --save-png and shown interactively otherwise.
    """
    beam_at_step = (
        beam_impactx_surrogate_series.iterations[step].particles["beam"].to_df()
    )
    ref_part_step = impactx_surrogate_ref_particle.loc[step]
    ref_u = np.sqrt(ref_part_step["pt"] ** 2 - 1)
    to_t(
        ref_u,
        ref_part_step["pt"],
        beam_at_step,
        ref_z=ref_part_step["z"],
        coord_type=TCoords.GLOBAL,
    )
    t_offset = impactx_surrogate_ref_particle.loc[step, "t"] * micron
    fig, axT = plt.subplots(3, 3, figsize=(10, 8))
    fig.suptitle(
        f"{title_prefix}, ct={impactx_surrogate_ref_particle.at[step, 't']:.2f} m"
    )
    plot_beam_df(
        beam_at_step,
        axT,
        alpha=0.6,
        color="red",
        unit_z=1e6,
        unit_z_label=r"$\xi$ ($\mu$m)",
        t_offset=t_offset,
        z_ticks=[-107.3, -106.6],
    )
    if args.save_png:
        plt.savefig(png_name)
    else:
        plt.show()


# initial
_plot_phase_spaces_at(1, "initially", "initial_phase_spaces.png")
####### final ###########
if args.stages_to_plot is not None:
    for stage_i in args.stages_to_plot:
        _plot_phase_spaces_at(
            ise[stage_i],
            f"stage {stage_i}",
            f"stage_{stage_i - 1}_phase_spaces.png",
        )

Fig. 14 Evolution of electron beam moments through 15 stages of LPAs (via neural network surrogates).

Fig. 15 Initial phase space projections going into 15 stage LPA (via neural network surrogates) simulation. Top row: spatial projections, middle row: momentum projections, bottom row: phase spaces.

Fig. 16 Final phase space projections after 15 stage LPA (via neural network surrogates) simulation. Top row: spatial projections, middle row: momentum projections, bottom row: phase spaces.