15 Stage Laser-Plasma Accelerator Surrogate

This example models an electron beam accelerated through fifteen stages of laser-plasma accelerators with ideal plasma lenses providing the focusing between stages. For more details, see:

  • Sandberg R T, Lehe R, Mitchell C E, Garten M, Myers A, Qiang J, Vay J-L and Huebl A. Synthesizing Particle-in-Cell Simulations Through Learning and GPU Computing for Hybrid Particle Accelerator Beamlines. Proc. of Platform for Advanced Scientific Computing (PASC’24), submitted, 2024. arXiv:2402.17248

  • Sandberg R T, Lehe R, Mitchell C E, Garten M, Qiang J, Vay J-L and Huebl A. Hybrid Beamline Element ML-Training for Surrogates in the ImpactX Beam-Dynamics Code. 14th International Particle Accelerator Conference (IPAC’23), WEPA101, 2023. DOI:10.18429/JACoW-IPAC2023-WEPA101

A schematic with more information can be seen in the figure below:

Fig. 8 Schematic of the 15 stages of laser-plasma accelerators.
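In the run script, the sequence in the schematic becomes a repeating block of nine lattice elements per stage: monitor, LPA surrogate, monitor, drift, monitor, lens-tuning element (UpdateConstF), plasma lens (ConstF), monitor, drift. The last stage has no transport section after it. This explains why UpdateConstF receives lattice_index=5 + 9 * i: get_lattice_element_iter advances past that index, so the next element it yields is the ConstF lens at 6 + 9 * i. The sketch below rebuilds the ordering with plain strings to check that index arithmetic. The helper lattice_layout and the element labels are hypothetical, not ImpactX objects.

```python
# Hypothetical helper: rebuilds the element ordering that the run script
# appends to sim.lattice, using string labels instead of ImpactX elements.
def lattice_layout(n_stage):
    names = []
    for i in range(n_stage):
        names += ["monitor", f"lpa_{i}"]
        if i != n_stage - 1:
            # transport section between stage i and stage i + 1
            names += [
                "monitor",
                "drift_1",
                "monitor",
                f"update_lens_{i}",
                f"lens_{i}",
                "monitor",
                "drift_2",
            ]
    names.append("monitor")  # final monitor after the last stage
    return names


layout = lattice_layout(15)
for i in range(14):
    # the lens-tuning element sits at 5 + 9*i and the lens it tunes right after it
    assert layout[5 + 9 * i] == f"update_lens_{i}"
    assert layout[6 + 9 * i] == f"lens_{i}"
print(len(layout))  # 14 * 9 + 2 + 1 = 129
```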

The laser-plasma accelerator elements are modeled with neural-network surrogates that are trained beforehand. In this example, the pre-trained networks are downloaded from a Zenodo archive and extracted into the models directory. For more about how these surrogate models were created, see this description of a workflow for training neural networks from WarpX simulation data.
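The run script downloads and extracts the archive on every run. A minimal sketch of a cached variant, assuming (as the run script does) that the archive extracts into a models directory, skips the download when that directory already exists. The function fetch_models and its parameters are hypothetical, not part of ImpactX; the fetch argument exists only so the download step can be replaced, for example in offline tests.

```python
import os
import zipfile
from urllib import request


def fetch_models(url, zip_path="models.zip", extract_dir=".",
                 model_dir="models", fetch=request.urlretrieve):
    """Download and extract the model archive unless model_dir already exists.

    Hypothetical helper: `fetch` defaults to urllib's urlretrieve and can be
    swapped for any callable with the same (url, filename) signature.
    Returns True if a download happened, False if it was skipped.
    """
    target = os.path.join(extract_dir, model_dir)
    if os.path.isdir(target):
        return False  # already extracted; skip the download
    fetch(url, zip_path)
    with zipfile.ZipFile(zip_path, "r") as archive:
        archive.extractall(extract_dir)
    return True
```

In the run script, a helper like this could stand in for the unconditional download_and_unzip(data_url, "models.zip") call.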

We use a 1 GeV electron beam with an initial normalized rms emittance of 1 mm-mrad.
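The run script encodes these values through the reference βγ (ref_u = 1957) and the Gaussian distribution parameters lambdaX and lambdaPx. The quick check below, which is not part of the example, confirms that they correspond to about 1 GeV and 1 mm-mrad. The unnormalized emittance of roughly 5.1e-10 m is also consistent with the initial emittance_x that the analysis script expects.

```python
import math

m_e_MeV = 0.510998950  # electron rest energy (MeV), as in the run script
ref_u = 1957  # reference beta*gamma
gamma = math.sqrt(1 + ref_u**2)
total_energy_MeV = m_e_MeV * gamma  # about 1000 MeV = 1 GeV

# Gaussian distribution parameters used in the run script
lambda_x = 0.75e-6  # rms size (m)
lambda_px = 1.33 / gamma  # rms divergence; px is normalized to the reference momentum

# for an uncorrelated beam, the rms emittance is the product of the rms widths
emit_geometric = lambda_x * lambda_px  # about 5.1e-10 m, i.e. about 0.5 nm
emit_normalized = ref_u * emit_geometric  # about 1.0e-6 m = 1 mm-mrad

print(f"E = {total_energy_MeV:.1f} MeV")
print(f"eps = {emit_geometric:.3e} m, eps_n = {emit_normalized:.3e} m")
```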

In this test, the initial and final values of \(\sigma_x\), \(\sigma_y\), \(\sigma_t\), \(\epsilon_x\), \(\epsilon_y\), and \(\epsilon_t\) must agree with nominal values.
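The analysis script compares these moments to nominal values with a relative tolerance of 1/√N, the order of the sampling noise for N macroparticles, and sets atol = 0 so that only the relative test applies. The initial longitudinal emittance, which is nominally zero, is instead checked against an absolute tolerance. The sketch below illustrates the relative check using the nominal initial beam sizes from the analysis script. The perturbed "measured" values are made up for illustration.

```python
import math

import numpy as np

num_particles = 100_000
# sampling error of a moment estimated from N random samples scales as 1/sqrt(N)
rtol = num_particles**-0.5  # about 3.2e-3

# nominal initial sig_x, sig_y, sig_t from the analysis script
nominal = np.array([7.494325e-07, 7.478916e-07, 9.976192e-08])
measured_ok = nominal * (1 + 0.5 * rtol)  # deviation inside the tolerance
measured_bad = nominal * (1 + 3.0 * rtol)  # deviation outside the tolerance

# np.allclose tests |a - b| <= atol + rtol * |b| element-wise
ok_passes = np.allclose(measured_ok, nominal, rtol=rtol, atol=0.0)
bad_passes = np.allclose(measured_bad, nominal, rtol=rtol, atol=0.0)
print(ok_passes, bad_passes)
```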

Run

This example can only be run with Python:

  • Python script: python3 run_ml_surrogate_15_stage.py

For MPI-parallel runs, prefix these lines with mpiexec -n 4 ... or srun -n 4 ..., depending on the system.

Listing 109 You can copy this file from examples/pytorch_surrogate_model/run_ml_surrogate_15_stage.py.
#!/usr/bin/env python3
#
# Copyright 2022-2023 ImpactX contributors
# Authors: Ryan Sandberg, Axel Huebl, Chad Mitchell
# License: BSD-3-Clause-LBNL
#
# -*- coding: utf-8 -*-

import argparse

import amrex.space3d as amr

try:
    import cupy as cp

    cupy_available = True
except ImportError:
    cupy_available = False
import sys

import numpy as np
import scipy.optimize as opt
from impactx import (
    Config,
    CoordSystem,
    ImpactX,
    ImpactXParIter,
    coordinate_transformation,
    distribution,
    elements,
)
import zipfile
from urllib import request

try:
    import torch
except ImportError:
    print("Warning: Cannot import PyTorch. Skipping test.")
    sys.exit(0)

# imported only after the PyTorch check, since this module imports torch itself
from surrogate_model_definitions import surrogate_model

parser = argparse.ArgumentParser()
parser.add_argument(
    "--num_particles",
    "-N",
    type=int,
    default=100000,
    help="number of particles to use in beam",
)
parser.add_argument(
    "--N_stages",
    "-ns",
    type=int,
    default=15,
    choices=range(1, 16),
    help="number of LPA stages to simulate",
)
args = parser.parse_args()
if Config.have_gpu and cupy_available:
    array = cp.array
    stack = cp.stack
    sqrt = cp.sqrt
    device = torch.device("cuda")
    if Config.gpu_backend == "SYCL":
        print("Warning: SYCL GPU backend not yet implemented for Python")
else:
    array = np.array
    stack = np.stack
    sqrt = np.sqrt
    device = None
if device is not None:
    print(f"device={device}")
else:
    print("device set to default, cpu")

N_stage = args.N_stages
tune_by_x_or_y = "x"
npart = args.num_particles
ebeam_lpa_z0 = -107e-6
L_plasma = 0.28
L_transport = 0.03
L_stage_period = L_plasma + L_transport
drift_after_LPA = 43e-6
L_surrogate = abs(ebeam_lpa_z0) + L_plasma + drift_after_LPA


def download_and_unzip(url, zip_path):
    """Download the zip archive at url to zip_path and extract it into the working directory."""
    request.urlretrieve(url, zip_path)
    with zipfile.ZipFile(zip_path, "r") as zip_dataset:
        zip_dataset.extractall()


data_url = "https://zenodo.org/records/10810754/files/models.zip?download=1"
download_and_unzip(data_url, "models.zip")

model_list = [
    surrogate_model(f"models/beam_stage_{stage_i}_model.pt", device=device)
    for stage_i in range(N_stage)
]

pp_amrex = amr.ParmParse("amrex")
pp_amrex.add("the_arena_init_size", 0)
pp_amrex.add("the_device_arena_init_size", 0)

sim = ImpactX()

# set numerical parameters and IO control
sim.particle_shape = 2  # B-spline order
sim.space_charge = False
sim.diagnostics = True  # benchmarking
sim.slice_step_diagnostics = True

# domain decomposition & space charge mesh
sim.init_grids()

# load a 1 GeV electron beam with an initial
# normalized rms emittance of 1 mm-mrad (unnormalized: about 0.5 nm)
ref_u = 1957
energy_gamma = np.sqrt(1 + ref_u**2)
energy_MeV = 0.510998950 * energy_gamma  # reference energy
bunch_charge_C = 10.0e-15  # used with space charge


#   reference particle
ref = sim.particle_container().ref_particle()
ref.set_charge_qe(-1.0).set_mass_MeV(0.510998950).set_kin_energy_MeV(energy_MeV)
ref.z = ebeam_lpa_z0

pc = sim.particle_container()

distr = distribution.Gaussian(
    lambdaX=0.75e-6,
    lambdaY=0.75e-6,
    lambdaT=0.1e-6,
    lambdaPx=1.33 / energy_gamma,
    lambdaPy=1.33 / energy_gamma,
    lambdaPt=1e-8,
    muxpx=0.0,
    muypy=0.0,
    mutpt=0.0,
)
sim.add_particles(bunch_charge_C, distr, npart)

n_slice = 1


class LPASurrogateStage(elements.Programmable):
    def __init__(self, stage_i, surrogate_model, surrogate_length, stage_start):
        elements.Programmable.__init__(self)
        self.stage_i = stage_i
        self.surrogate_model = surrogate_model
        self.surrogate_length = surrogate_length
        self.stage_start = stage_start
        self.push = self.surrogate_push
        self.ds = surrogate_length

    def surrogate_push(self, pc, step):

        ref_part = pc.ref_particle()
        ref_z_i = ref_part.z
        ref_z_i_LPA = ref_z_i - self.stage_start
        ref_z_f = ref_z_i + self.surrogate_length
        ref_part_tensor = torch.tensor(
            [
                ref_part.x,
                ref_part.y,
                ref_z_i_LPA,
                ref_part.px,
                ref_part.py,
                ref_part.pz,
            ],
            device=device,
            dtype=torch.float64,
        )
        ref_beta_gamma = torch.sqrt(torch.sum(ref_part_tensor[3:] ** 2))
        ref_beta_gamma = ref_beta_gamma.to(device)

        with torch.no_grad():
            ref_part_model_final = self.surrogate_model(ref_part_tensor)
        ref_uz_f = ref_part_model_final[5]
        ref_beta_gamma_final = ref_uz_f
        ref_part_final = torch.tensor(
            [0, 0, ref_z_f, 0, 0, ref_uz_f], device=device, dtype=torch.float64
        )

        coordinate_transformation(pc, CoordSystem.t)

        for lvl in range(pc.finest_level + 1):
            for pti in ImpactXParIter(pc, level=lvl):
                soa = pti.soa()
                real_arrays = soa.get_real_data()
                x = array(real_arrays[0], copy=False)
                y = array(real_arrays[1], copy=False)
                t = array(real_arrays[2], copy=False)
                px = array(real_arrays[3], copy=False)
                py = array(real_arrays[4], copy=False)
                pt = array(real_arrays[5], copy=False)
                data_arr = torch.tensor(
                    stack([x, y, t, px, py, pt], axis=1),
                    device=device,
                    dtype=torch.float64,
                )

                data_arr[:, 0] += ref_part.x
                data_arr[:, 1] += ref_part.y
                data_arr[:, 2] += ref_z_i_LPA
                data_arr[:, 3:] *= ref_beta_gamma
                data_arr[:, 3] += ref_part.px
                data_arr[:, 4] += ref_part.py
                data_arr[:, 5] += ref_part.pz

                with torch.no_grad():
                    data_arr_post_model = self.surrogate_model(data_arr)

                #  z += stage start
                data_arr_post_model[:, 2] += self.stage_start
                # back to ref particle coordinates
                for ii in range(3):
                    data_arr_post_model[:, ii] -= ref_part_final[ii]
                    data_arr_post_model[:, 3 + ii] -= ref_part_final[3 + ii]
                    data_arr_post_model[:, 3 + ii] /= ref_beta_gamma_final

                x[:] = array(data_arr_post_model[:, 0])
                y[:] = array(data_arr_post_model[:, 1])
                t[:] = array(data_arr_post_model[:, 2])
                px[:] = array(data_arr_post_model[:, 3])
                py[:] = array(data_arr_post_model[:, 4])
                pt[:] = array(data_arr_post_model[:, 5])

        # TODO this part needs to be corrected for general geometry
        # where the initial vector might not point in z
        # and even if it does, bending elements may change the direction

        ref_part.x = ref_part_final[0]
        ref_part.y = ref_part_final[1]
        ref_part.z = ref_part_final[2]
        ref_gamma = torch.sqrt(1 + ref_beta_gamma_final**2)
        ref_part.px = ref_part_final[3]
        ref_part.py = ref_part_final[4]
        ref_part.pz = ref_part_final[5]
        ref_part.pt = -ref_gamma
        ref_part.s += self.surrogate_length
        ref_part.t += self.surrogate_length

        coordinate_transformation(pc, CoordSystem.s)
        ## Done!


L_transport = 0.03
L_lens = 0.003
L_focal = 0.5 * L_transport
L_drift = 0.5 * (L_transport - L_lens)
K = np.sqrt(2.0 / L_focal / L_lens)
Kt = 1e-11  # number chosen arbitrarily since 0 isn't allowed
dL = 0

L_drift_minus_surrogate = L_drift
L_drift_1 = L_drift - drift_after_LPA - dL

L_drift_before_2nd_stage = abs(ebeam_lpa_z0)
L_drift_2 = L_drift - L_drift_before_2nd_stage + dL


def get_lattice_element_iter(sim, j):
    assert (
        0 <= j < len(sim.lattice)
    ), f"Argument j must be a nonnegative integer satisfying 0 <= j < {len(sim.lattice)}, not {j}"
    i = 0
    lat_it = sim.lattice.__iter__()
    next(lat_it)
    while i != j:
        next(lat_it)
        i += 1
    return lat_it


def lens_eqn(k, lens_length, alpha, beta, gamma):
    return np.tan(k * lens_length) + 2 * alpha / (k * beta - gamma / k)


k_list = []


class UpdateConstF(elements.Programmable):
    def __init__(self, sim, stage_i, lattice_index, x_or_y):
        elements.Programmable.__init__(self)
        self.sim = sim
        self.stage_i = stage_i
        self.lattice_index = lattice_index
        self.x_or_y = x_or_y
        self.push = self.set_lens

    def set_lens(self, pc, step):
        # push receives (pc, step), the same signature as LPASurrogateStage.surrogate_push
        # get envelope parameters
        rbc = pc.reduced_beam_characteristics()
        alpha = rbc[f"alpha_{self.x_or_y}"]
        beta = rbc[f"beta_{self.x_or_y}"]
        gamma = (1 + alpha**2) / beta
        # solve for k_new
        sol = opt.root_scalar(
            lens_eqn, bracket=[100, 300], args=(L_lens, alpha, beta, gamma)
        )
        k_new = sol.root
        # set lens
        self_it = get_lattice_element_iter(self.sim, self.lattice_index)
        following_lens = next(self_it)
        k_list.append(k_new)
        following_lens.kx = k_new
        following_lens.ky = k_new


lpa_stages = []
for i in range(N_stage):
    lpa = LPASurrogateStage(i, model_list[i], L_surrogate, L_stage_period * i)
    lpa.nslice = n_slice
    lpa.ds = L_surrogate
    lpa_stages.append(lpa)

monitor = elements.BeamMonitor("monitor")
for i in range(N_stage):
    sim.lattice.extend(
        [
            monitor,
            lpa_stages[i],
        ]
    )

    if i != N_stage - 1:
        sim.lattice.extend(
            [
                monitor,
                elements.Drift(ds=L_drift_1),
                monitor,
                UpdateConstF(
                    sim=sim, stage_i=i, lattice_index=5 + 9 * i, x_or_y=tune_by_x_or_y
                ),
                elements.ConstF(ds=L_lens, kx=K, ky=K, kt=Kt),
                monitor,
                elements.Drift(ds=L_drift_2),
            ]
        )
sim.lattice.extend([monitor])

sim.evolve()
sim.finalize()
del sim
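UpdateConstF chooses the plasma-lens strength k from the beam's Twiss parameters at the lens entrance by solving lens_eqn, tan(kL) + 2α/(kβ − γ/k) = 0. Our reading of this condition (a derivation, not stated in the source) is that it requires a thick focusing lens with transfer matrix [[cos kL, sin kL / k], [−k sin kL, cos kL]] to flip the sign of α, i.e. α_out = −α_in: expanding the Twiss transformation gives α_out + α_in = cos²(kL)·(kβ − γ/k)·lens_eqn(k). The sketch checks this numerically. It assumes the standard thick-lens matrix for ConstF, and the Twiss values (a 1.1 mm waist followed by a 13.5 mm drift) are illustrative choices, not taken from the simulation.

```python
import math

from scipy.optimize import root_scalar


def lens_eqn(k, lens_length, alpha, beta, gamma):
    # same residual as in the run script
    return math.tan(k * lens_length) + 2 * alpha / (k * beta - gamma / k)


def thick_lens_alpha_out(k, lens_length, alpha, beta, gamma):
    """Propagate Twiss alpha through a thick focusing lens of strength k (1/m)."""
    c = math.cos(k * lens_length)
    s = math.sin(k * lens_length)
    m11, m12, m21, m22 = c, s / k, -k * s, c
    # standard Twiss transformation for alpha
    return -m11 * m21 * beta + (m11 * m22 + m12 * m21) * alpha - m12 * m22 * gamma


# illustrative Twiss values: 1.1 mm waist, then 13.5 mm of drift
beta_star, drift = 1.1e-3, 13.5e-3
gamma = 1.0 / beta_star  # Twiss gamma is invariant in a drift
beta = beta_star + drift**2 * gamma
alpha = -drift * gamma  # negative alpha: diverging beam

L_lens = 0.003
sol = root_scalar(lens_eqn, bracket=[100, 300], args=(L_lens, alpha, beta, gamma))
k = sol.root
alpha_out = thick_lens_alpha_out(k, L_lens, alpha, beta, gamma)
print(f"k = {k:.1f} 1/m, alpha_in = {alpha:.3f}, alpha_out = {alpha_out:.3f}")
```

The solved k lands near the script's initial guess K = sqrt(2 / L_focal / L_lens) ≈ 211 1/m, which is a useful sanity check on the bracket [100, 300].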

This script requires utility code for loading and evaluating the neural networks, provided here:

Script surrogate_model_definitions.py
Listing 110 You can copy this file from examples/pytorch_surrogate_model/surrogate_model_definitions.py.
#!/usr/bin/env python3
#
# Copyright 2022-2023 ImpactX contributors
# Authors: Ryan Sandberg, Axel Huebl
# License: BSD-3-Clause-LBNL
#
# -*- coding: utf-8 -*-

from enum import Enum

import torch
from torch import nn


class Activation(Enum):
    """
    Activation class provides an enumeration type for the supported activation layers
    """

    ReLU = 1
    Tanh = 2
    PReLU = 3
    Sigmoid = 4


def get_enum_type(type_to_test, EnumClass):
    """
    Returns the enumeration type associated to type_to_test in EnumClass

    Parameters
    ----------
    type_to_test: EnumClass, int or str
        object whose Enum class is to be obtained
    EnumClass: Enum class
        Enum class to test
    """
    if isinstance(type_to_test, EnumClass):  # already an enum member
        return type_to_test
    if isinstance(type_to_test, int):
        return EnumClass(type_to_test)
    if isinstance(type_to_test, str):
        return getattr(EnumClass, type_to_test)


class ConnectedNN(nn.Module):
    """
    ConnectedNN is a class of fully connected neural networks
    """

    def __init__(self, layers, device=None):
        super().__init__()
        self.stack = nn.Sequential(*layers)
        if device is not None:
            self.to(device)

    def forward(self, x):
        return self.stack(x)


class OneActNN(ConnectedNN):
    """
    OneActNN is a class of fully connected neural networks that use a single activation function type
    """

    def __init__(self, n_in, n_out, n_hidden_nodes, n_hidden_layers, act, device=None):
        self.n_in = n_in
        self.n_out = n_out
        self.n_hidden_layers = n_hidden_layers
        self.n_hidden_nodes = n_hidden_nodes
        self.act = act

        layers = [nn.Linear(self.n_in, self.n_hidden_nodes)]

        for ii in range(self.n_hidden_layers):
            if self.act is Activation.ReLU:
                layers += [nn.ReLU()]
            if self.act is Activation.Tanh:
                layers += [nn.Tanh()]
            if self.act is Activation.PReLU:
                layers += [nn.PReLU()]
            if self.act is Activation.Sigmoid:
                layers += [nn.Sigmoid()]

            if ii < self.n_hidden_layers - 1:
                layers += [nn.Linear(self.n_hidden_nodes, self.n_hidden_nodes)]

        layers += [nn.Linear(self.n_hidden_nodes, self.n_out)]

        super().__init__(layers, device)


class surrogate_model:
    """
    Extend the functionality of the OneActNN class

    This class is meant to act as a wrapper for the OneActNN class.
    It provides a `__call__` function that normalizes input and returns dimensional output.
    """

    def __init__(self, model_file, device=None):
        self.device = device
        if device is None:
            model_dict = torch.load(model_file, map_location="cpu")
        else:
            model_dict = torch.load(model_file, map_location=device)
        self.source_means = torch.tensor(
            model_dict["source_means"], device=self.device, dtype=torch.float64
        )
        self.target_means = torch.tensor(
            model_dict["target_means"], device=self.device, dtype=torch.float64
        )
        self.source_stds = torch.tensor(
            model_dict["source_stds"], device=self.device, dtype=torch.float64
        )
        self.target_stds = torch.tensor(
            model_dict["target_stds"], device=self.device, dtype=torch.float64
        )
        n_in = model_dict["model_state_dict"]["stack.0.weight"].shape[1]
        final_layer_key = list(model_dict["model_state_dict"].keys())[-1]
        n_out = model_dict["model_state_dict"][final_layer_key].shape[0]
        n_hidden_nodes = model_dict["model_state_dict"]["stack.0.weight"].shape[0]
        activation_type = model_dict["activation"]
        activation = get_enum_type(activation_type, Activation)
        if "n_hidden_layers" in model_dict.keys():
            n_hidden_layers = model_dict["n_hidden_layers"]
        else:
            if activation is Activation.PReLU:
                n_hidden_layers = int(
                    (len(model_dict["model_state_dict"].keys()) - 2) / 3
                )
            else:
                n_hidden_layers = int(
                    len(model_dict["model_state_dict"].keys()) / 2 - 1
                )

        self.neural_network = OneActNN(
            n_in=n_in,
            n_out=n_out,
            n_hidden_nodes=n_hidden_nodes,
            n_hidden_layers=n_hidden_layers,
            act=activation,
            device=device,
        )
        self.neural_network.load_state_dict(model_dict["model_state_dict"])
        self.neural_network.eval()

    def __call__(self, data_arr):
        data_arr -= self.source_means
        data_arr /= self.source_stds
        with torch.no_grad():
            data_arr_post_model = self.neural_network(data_arr.float()).double()
        data_arr_post_model *= self.target_stds
        data_arr_post_model += self.target_means
        return data_arr_post_model
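surrogate_model.__call__ standardizes its input with the training-set means and standard deviations stored in the model file, evaluates the network in single precision, and maps the output back to physical units with the target statistics. It subtracts and divides in place, so the tensor passed in is modified; in the run script this is harmless because the input tensors are not reused afterwards. The NumPy sketch below reproduces the same normalize, evaluate, denormalize steps without mutating its input and without the float32 cast. The function name and the identity "network" are illustrative stand-ins, not part of the example.

```python
import numpy as np


def standardized_call(network, data, source_means, source_stds, target_means, target_stds):
    """Evaluate `network` on standardized inputs and return dimensional outputs.

    Mirrors the steps of surrogate_model.__call__, but builds new arrays
    instead of updating `data` in place.
    """
    x = (np.asarray(data, dtype=np.float64) - source_means) / source_stds
    y = network(x)
    return y * target_stds + target_means


def identity(x):
    # toy stand-in for the trained network
    return x


rng = np.random.default_rng(0)
data = rng.normal(size=(4, 6))  # 4 particles, 6 phase-space coordinates
src_mu, src_sd = rng.normal(size=6), rng.uniform(1.0, 2.0, size=6)
tgt_mu, tgt_sd = rng.normal(size=6), rng.uniform(1.0, 2.0, size=6)

before = data.copy()
out = standardized_call(identity, data, src_mu, src_sd, tgt_mu, tgt_sd)
```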

Analyze

We run the following script to analyze correctness:

Script analyze_ml_surrogate_15_stage.py
Listing 111 You can copy this file from examples/pytorch_surrogate_model/analyze_ml_surrogate_15_stage.py.
#!/usr/bin/env python3
#
# Copyright 2022-2023 ImpactX contributors
# Authors: Ryan Sandberg, Axel Huebl, Chad Mitchell
# License: BSD-3-Clause-LBNL
#
# -*- coding: utf-8 -*-

import numpy as np
import openpmd_api as io
from scipy.stats import moment


def get_moments(beam):
    """Calculate standard deviations of beam position & momenta
    and emittance values

    Returns
    -------
    sigx, sigy, sigt, emittance_x, emittance_y, emittance_t
    """
    sigx = moment(beam["position_x"], moment=2) ** 0.5  # variance -> std dev.
    sigpx = moment(beam["momentum_x"], moment=2) ** 0.5
    sigy = moment(beam["position_y"], moment=2) ** 0.5
    sigpy = moment(beam["momentum_y"], moment=2) ** 0.5
    sigt = moment(beam["position_t"], moment=2) ** 0.5
    sigpt = moment(beam["momentum_t"], moment=2) ** 0.5

    epstrms = beam.cov(ddof=0)
    emittance_x = (sigx**2 * sigpx**2 - epstrms["position_x"]["momentum_x"] ** 2) ** 0.5
    emittance_y = (sigy**2 * sigpy**2 - epstrms["position_y"]["momentum_y"] ** 2) ** 0.5
    emittance_t = (sigt**2 * sigpt**2 - epstrms["position_t"]["momentum_t"] ** 2) ** 0.5

    return (sigx, sigy, sigt, emittance_x, emittance_y, emittance_t)


# initial/final beam
series = io.Series("diags/openPMD/monitor.bp", io.Access.read_only)
last_step = list(series.iterations)[-1]
initial = series.iterations[1].particles["beam"].to_df()
final = series.iterations[last_step].particles["beam"].to_df()

# compare number of particles
num_particles = 100000
assert num_particles == len(initial)
assert num_particles == len(final)

print("Initial Beam:")
sigx, sigy, sigt, emittance_x, emittance_y, emittance_t = get_moments(initial)
print(f"  sigx={sigx:e} sigy={sigy:e} sigt={sigt:e}")
print(
    f"  emittance_x={emittance_x:e} emittance_y={emittance_y:e} emittance_t={emittance_t:e}"
)

atol = 0.0  # ignored
rtol = num_particles**-0.5  # from random sampling of a smooth distribution
print(f"  rtol={rtol} (ignored: atol~={atol})")

assert np.allclose(
    [sigx, sigy, sigt, emittance_x, emittance_y],
    [
        7.494325e-07,
        7.478916e-07,
        9.976192e-08,
        5.070297e-10,
        5.080007e-10,
    ],
    rtol=rtol,
    atol=atol,
)

atol = 1.0e-6
print(f"  atol~={atol}")
assert np.allclose([emittance_t], [0.0], atol=atol)

print("")
print("Final Beam:")
sigx, sigy, sigt, emittance_x, emittance_y, emittance_t = get_moments(final)
print(f"  sigx={sigx:e} sigy={sigy:e} sigt={sigt:e}")
print(
    f"  emittance_x={emittance_x:e} emittance_y={emittance_y:e} emittance_t={emittance_t:e}"
)

atol = 0.0  # ignored
rtol = num_particles**-0.5  # from random sampling of a smooth distribution
print(f"  rtol={rtol} (ignored: atol~={atol})")

assert np.allclose(
    [sigx, sigy, sigt, emittance_x, emittance_y, emittance_t],
    [
        1.590999e-07,
        1.634865e-07,
        1.030930e-07,
        5.031797e-12,
        5.242205e-12,
        2.049623e-11,
    ],
    rtol=rtol,
    atol=atol,
)
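get_moments computes each rms emittance as sqrt(σx²σpx² − ⟨x px⟩²), which is the square root of the determinant of the 2×2 position-momentum covariance matrix (population estimator, ddof=0). The sketch below uses that determinant form on synthetic Gaussian data with the sizes from the run script. It also shows that a linear correlation, here a thin-lens-like shear with an illustrative focal length f, leaves the emittance unchanged, because the shear matrix has determinant 1.

```python
import numpy as np


def rms_emittance(pos, mom):
    """rms emittance sqrt(<x^2><p^2> - <x p>^2) as sqrt(det) of the 2x2 covariance."""
    cov = np.cov(pos, mom, ddof=0)
    return np.sqrt(np.linalg.det(cov))


rng = np.random.default_rng(42)
n = 200_000
x = rng.normal(0.0, 0.75e-6, n)  # rms size as in the run script (m)
px = rng.normal(0.0, 6.8e-4, n)  # rms divergence (rad)

eps = rms_emittance(x, px)  # close to 0.75e-6 * 6.8e-4 = 5.1e-10 m

# a linear shear px -> px - x / f has determinant 1: it correlates x and px
# but leaves the rms emittance unchanged (f is illustrative)
f = 1.0e-3
eps_sheared = rms_emittance(x, px - x / f)
print(f"eps = {eps:.3e} m, eps_sheared = {eps_sheared:.3e} m")
```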

Visualize

You can run the following script to visualize the beam evolution over time:

Script visualize_ml_surrogate_15_stage.py
Listing 112 You can copy this file from examples/pytorch_surrogate_model/visualize_ml_surrogate_15_stage.py.
#!/usr/bin/env python3
#
# Copyright 2022-2023 ImpactX contributors
# Authors: Ryan Sandberg, Axel Huebl, Chad Mitchell
# License: BSD-3-Clause-LBNL
#
# -*- coding: utf-8 -*-

import argparse
import glob
import re

import numpy as np
import openpmd_api as io
import pandas as pd
from matplotlib import pyplot as plt
from scipy.constants import c, e, m_e


def read_all_files(file_pattern):
    """Read in all CSV files from each MPI rank (and potentially OpenMP
    thread). Concatenate into one Pandas dataframe.
    Returns
    -------
    pandas.DataFrame
    """
    return pd.concat(
        (
            pd.read_csv(filename, delimiter=r"\s+")
            for filename in glob.glob(file_pattern)
        ),
        axis=0,
        ignore_index=True,
    ).set_index("id")


def read_file(file_pattern):
    for filename in glob.glob(file_pattern):
        df = pd.read_csv(filename, delimiter=r"\s+")
        if "step" not in df.columns:
            step = int(re.findall(r"[0-9]+", filename)[0])
            df["step"] = step
        yield df


def read_time_series(file_pattern):
    """Read in all CSV files from each MPI rank (and potentially OpenMP
    thread). Concatenate into one Pandas dataframe.

    Returns
    -------
    pandas.DataFrame
    """
    return pd.concat(
        read_file(file_pattern),
        axis=0,
        ignore_index=True,
    )  # .set_index('id')


from enum import Enum


class TCoords(Enum):
    REF = 1
    GLOBAL = 2


def to_t(
    ref_pz, ref_pt, data_arr_s, ref_z=None, coord_type=TCoords.REF
):  # x, y, t, dpx, dpy, dpt):
    """Change to fixed t coordinates

    Parameters
    ----------
    ref_pz: float, reference particle momentum in z
    ref_pt: float, reference particle pt = -gamma
    data_arr_s: Nx6 array-like structure containing fixed-s particle coordinates
    ref_z: if transforming to global coordinates
    coord_type: TCoords enum, (default is in ref coordinates) whether to get particle data relative to reference coordinate or in the global frame
    """
    if type(data_arr_s) is pd.core.frame.DataFrame:
        dx = data_arr_s["position_x"]
        dy = data_arr_s["position_y"]
        dt = data_arr_s["position_t"]
        dpx = data_arr_s["momentum_x"]
        dpy = data_arr_s["momentum_y"]
        dpt = data_arr_s["momentum_t"]

    elif type(data_arr_s) is np.ndarray:
        assert (
            data_arr_s.shape[1] == 6
        ), f"data_arr_s.shape={data_arr_s.shape} but data_arr_s must be an Nx6 array"
        dx, dy, dt, dpx, dpy, dpt = data_arr_s.T
    else:
        raise Exception(
            f"Incompatible input type {type(data_arr_s)} for data_arr_s, must be pandas DataFrame or Nx6 array-like object"
        )
    dx += ref_pz * dpx * dt / (ref_pt + ref_pz * dpt)
    dy += ref_pz * dpy * dt / (ref_pt + ref_pz * dpt)
    pz = np.sqrt(
        -1 + (ref_pt + ref_pz * dpt) ** 2 - (ref_pz * dpx) ** 2 - (ref_pz * dpy) ** 2
    )
    dt *= pz / (ref_pt + ref_pz * dpt)
    if type(data_arr_s) is pd.core.frame.DataFrame:
        data_arr_s["momentum_t"] = pz - ref_pz
        dpt = data_arr_s["momentum_t"]
    else:
        dpt[:] = pz - ref_pz
    if coord_type is TCoords.REF:
        print("applying reference normalization")
        dpt /= ref_pz
    elif coord_type is TCoords.GLOBAL:
        assert (
            ref_z is not None
        ), "Reference particle z coordinate is required to transform to global coordinates"
        print("target global coordinates")
        dt += ref_z
        dpx *= ref_pz
        dpy *= ref_pz
        dpt += ref_pz
    # data_arr_t = np.column_stack([xt,yt,z,dpx,dpy,dpz])
    return  # modifies data_arr_s in place


def plot_beam_df(
    beam_at_step,
    axT,
    unit=1e6,
    unit_z=1e3,
    unit_label=r"$\mu$m",
    unit_z_label="mm",
    alpha=1.0,
    cmap=None,
    color="k",
    size=0.1,
    t_offset=0.0,
    label=None,
    z_ticks=None,
):
    ax = axT[0][0]
    ax.scatter(
        beam_at_step.position_x.multiply(unit),
        beam_at_step.position_y.multiply(unit),
        c=color,
        s=size,
        alpha=alpha,
        cmap=cmap,
    )
    ax.set_xlabel(r"x (%s)" % unit_label)
    ax.set_ylabel(r"y (%s)" % unit_label)
    ax.axes.ticklabel_format(axis="both", style="sci", scilimits=(-2, 2))
    ###########

    ax = axT[0][1]
    ax.scatter(
        beam_at_step.position_t.multiply(unit_z) - t_offset,
        beam_at_step.position_x.multiply(unit),
        c=color,
        s=size,
        alpha=alpha,
        cmap=cmap,
    )
    ax.set_xlabel(r"%s" % unit_z_label)
    ax.set_ylabel(r"x (%s)" % unit_label)
    ax.axes.ticklabel_format(axis="y", style="sci", scilimits=(-2, 2))
    if z_ticks is not None:
        ax.set_xticks(z_ticks)
    ###########

    ax = axT[0][2]
    ax.scatter(
        beam_at_step.position_t.multiply(unit_z) - t_offset,
        beam_at_step.position_y.multiply(unit),
        c=color,
        s=size,
        alpha=alpha,
        cmap=cmap,
    )
    ax.set_xlabel(r"%s" % unit_z_label)
    ax.set_ylabel(r"y (%s)" % unit_label)
    ax.axes.ticklabel_format(axis="y", style="sci", scilimits=(-2, 2))
    if z_ticks is not None:
        ax.set_xticks(z_ticks)
    ############
    ##########
    ax = axT[1][0]
    ax.scatter(
        beam_at_step.momentum_x,
        beam_at_step.momentum_y,
        c=color,
        s=size,
        alpha=alpha,
        cmap=cmap,
    )
    ax.set_xlabel("px")
    ax.set_ylabel("py")
    ax.axes.ticklabel_format(axis="both", style="sci", scilimits=(-2, 2))
    ##########
    ax = axT[1][1]
    ax.scatter(
        beam_at_step.momentum_t,
        #         beam_at_step.position_t.multiply(unit_z)-t_offset,
        beam_at_step.momentum_x,
        c=color,
        s=size,
        alpha=alpha,
        cmap=cmap,
    )
    ax.set_xlabel("pt")
    #     ax.set_xlabel(r'%s'%unit_z_label)
    ax.set_ylabel("px")
    ax.axes.ticklabel_format(axis="both", style="sci", scilimits=(-2, 2))
    ##########
    ax = axT[1][2]
    ax.scatter(
        beam_at_step.momentum_t,
        #         beam_at_step.position_t.multiply(unit_z)-t_offset,
        beam_at_step.momentum_y,
        c=color,
        s=size,
        alpha=alpha,
        label=label,
        cmap=cmap,
    )
    if label is not None:
        ax.legend()
    #     ax.set_xlabel(r'%s'%unit_z_label)
    ax.set_xlabel("pt")
    ax.set_ylabel("py")
    ax.axes.ticklabel_format(axis="both", style="sci", scilimits=(-2, 2))
    ############
    ############
    ##########

    ax = axT[2][0]
    ax.scatter(
        beam_at_step.position_x.multiply(unit),
        beam_at_step.momentum_x,
        c=color,
        s=size,
        alpha=alpha,
        cmap=cmap,
    )
    ax.set_xlabel(r"x (%s)" % unit_label)
    ax.set_ylabel("px")
    ax.axes.ticklabel_format(axis="both", style="sci", scilimits=(-2, 2))
    ############
    ax = axT[2][1]
    ax.scatter(
        beam_at_step.position_y.multiply(unit),
        beam_at_step.momentum_y,
        c=color,
        s=size,
        alpha=alpha,
        cmap=cmap,
    )
    ax.set_xlabel(r"y (%s)" % unit_label)
    ax.set_ylabel("py")
    ax.axes.ticklabel_format(axis="both", style="sci", scilimits=(-2, 2))

    ################
    ax = axT[2][2]
    ax.scatter(
        beam_at_step.position_t.multiply(unit_z) - t_offset,
        beam_at_step.momentum_t,
        c=color,
        s=size,
        alpha=alpha,
        cmap=cmap,
    )
    ax.set_xlabel(r"%s" % unit_z_label)
    ax.set_ylabel("pt")
    ax.axes.ticklabel_format(axis="y", style="sci", scilimits=(-2, 2))
    if z_ticks is not None:
        ax.set_xticks(z_ticks)
    plt.tight_layout()
    # done


# options to run this script
parser = argparse.ArgumentParser(description="Plot the ML surrogate benchmark.")
parser.add_argument(
    "--save-png", action="store_true", help="non-interactive run: save to PNGs"
)
parser.add_argument(
    "--num-stages", "-n", type=int, default=15, help="num stages to plot"
)
parser.add_argument(
    "--stages_to_plot", "-s", type=int, help="indices of the stages to plot", nargs="*"
)
args = parser.parse_args()

impactx_surrogate_reduced_diags = read_time_series(
    "diags/reduced_beam_characteristics.*"
)
ref_gamma = np.sqrt(1 + impactx_surrogate_reduced_diags["ref_beta_gamma"] ** 2)
beam_gamma = (
    ref_gamma
    - impactx_surrogate_reduced_diags["pt_mean"]
    * impactx_surrogate_reduced_diags["ref_beta_gamma"]
)
beam_u = np.sqrt(beam_gamma**2 - 1)
emit_x = impactx_surrogate_reduced_diags["emittance_x"]
emit_nx = emit_x * beam_u
emit_y = impactx_surrogate_reduced_diags["emittance_y"]
emit_ny = emit_y * beam_u

ix_slice = [0] + [2 + 9 * i for i in range(args.num_stages)]

############# plot moments ##############
fig, axT = plt.subplots(2, 2, figsize=(10, 8))
ymarker = "^"
######### emittance ##########
ax = axT[0][0]
scale = 1e6
ax.plot(
    impactx_surrogate_reduced_diags["s"][ix_slice],
    emit_nx[ix_slice] * scale,
    "bo",
    label="x",
)
ax.plot(
    impactx_surrogate_reduced_diags["s"][ix_slice],
    emit_ny[ix_slice] * scale,
    "r",
    marker=ymarker,
    linestyle="None",
    label="y",
)
ax.legend()
ax.set_xlabel("s (m)")
ax.set_ylabel(r"emittance (mm-mrad)")
######### energy ##########
ax = axT[0][1]
scale = m_e * c**2 / e * 1e-9
ax.plot(
    impactx_surrogate_reduced_diags["s"][ix_slice],
    beam_gamma[ix_slice] * scale,
    "go",
)
ax.set_xlabel("s (m)")
ax.set_ylabel(r"mean energy (GeV)")

######### width ##########
ax = axT[1][0]
scale = 1e6
ax.plot(
    impactx_surrogate_reduced_diags["s"][ix_slice],
    impactx_surrogate_reduced_diags["sig_x"][ix_slice] * scale,
    "bo",
    label="x",
)
ax.plot(
    impactx_surrogate_reduced_diags["s"][ix_slice],
    impactx_surrogate_reduced_diags["sig_y"][ix_slice] * scale,
    "r",
    marker=ymarker,
    linestyle="None",
    label="y",
)
ax.legend()
ax.set_xlabel("s (m)")
ax.set_ylabel(r"beam width ($\mu$m)")

######### divergence ##########
ax = axT[1][1]
scale = 1e3
ax.semilogy(
    impactx_surrogate_reduced_diags["s"][ix_slice],
    impactx_surrogate_reduced_diags["sig_px"][ix_slice] * scale,
    "bo",
    label="x",
)
ax.semilogy(
    impactx_surrogate_reduced_diags["s"][ix_slice],
    impactx_surrogate_reduced_diags["sig_py"][ix_slice] * scale,
    "r",
    marker=ymarker,
    linestyle="None",
    label="y",
)
ax.legend()
ax.set_xlabel("s (m)")
ax.set_ylabel(r"divergence (mrad)")

plt.tight_layout()

if args.save_png:
    plt.savefig("lpa_ml_surrogate_moments.png")
else:
    plt.show()


######## plot phase spaces ###########
beam_impactx_surrogate_series = io.Series(
    "diags/openPMD/monitor.bp", io.Access.read_only
)
impactx_surrogate_steps = list(beam_impactx_surrogate_series.iterations)
impactx_surrogate_ref_particle = read_time_series("diags/ref_particle.*")

millimeter = 1.0e3
micron = 1.0e6

N_stage = args.num_stages
impactx_stage_end_steps = [1] + [3 + 9 * i for i in range(N_stage)]
ise = impactx_stage_end_steps

# initial

step = 1
beam_at_step = beam_impactx_surrogate_series.iterations[step].particles["beam"].to_df()
ref_part_step = impactx_surrogate_ref_particle.loc[step]
ref_u = np.sqrt(ref_part_step["pt"] ** 2 - 1)
to_t(
    ref_u,
    ref_part_step["pt"],
    beam_at_step,
    ref_z=ref_part_step["z"],
    coord_type=TCoords.GLOBAL,
)

t_offset = impactx_surrogate_ref_particle.loc[step, "t"] * micron
fig, axT = plt.subplots(3, 3, figsize=(10, 8))
fig.suptitle(f"initially, ct={impactx_surrogate_ref_particle.at[step,'t']:.2f} m")

plot_beam_df(
    beam_at_step,
    axT,
    alpha=0.6,
    color="red",
    unit_z=1e6,
    unit_z_label=r"$\xi$ ($\mu$m)",
    t_offset=t_offset,
    z_ticks=[-107.3, -106.6],
)
if args.save_png:
    plt.savefig("initial_phase_spaces.png")
else:
    plt.show()

####### final ###########
if args.stages_to_plot is not None:
    for stage_i in args.stages_to_plot:
        step = ise[stage_i]
        beam_at_step = (
            beam_impactx_surrogate_series.iterations[step].particles["beam"].to_df()
        )
        ref_part_step = impactx_surrogate_ref_particle.loc[step]
        ref_u = np.sqrt(ref_part_step["pt"] ** 2 - 1)
        to_t(
            ref_u,
            ref_part_step["pt"],
            beam_at_step,
            ref_z=ref_part_step["z"],
            coord_type=TCoords.GLOBAL,
        )

        t_offset = impactx_surrogate_ref_particle.loc[step, "t"] * micron
        fig, axT = plt.subplots(3, 3, figsize=(10, 8))
        fig.suptitle(
            f"stage {stage_i}, ct={impactx_surrogate_ref_particle.at[step,'t']:.2f} m"
        )

        plot_beam_df(
            beam_at_step,
            axT,
            alpha=0.6,
            color="red",
            unit_z=1e6,
            unit_z_label=r"$\xi$ ($\mu$m)",
            t_offset=t_offset,
            z_ticks=[-107.3, -106.6],
        )
        if args.save_png:
            plt.savefig(f"stage_{stage_i-1}_phase_spaces.png")
        else:
            plt.show()

Fig. 9 Evolution of electron beam moments through 15 stages of LPAs (via neural network surrogates).

Fig. 10 Initial phase space projections entering the 15-stage LPA simulation (via neural network surrogates). Top row: spatial projections, middle row: momentum projections, bottom row: phase spaces.

Fig. 11 Final phase space projections after the 15-stage LPA simulation (via neural network surrogates). Top row: spatial projections, middle row: momentum projections, bottom row: phase spaces.
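For the moment plots, the visualization script converts ImpactX's reduced diagnostics to a mean beam energy. It computes the mean Lorentz factor as γ = γ_ref − p_t·(βγ)_ref, so a negative p_t means more energy than the reference particle, and then converts to GeV with m_e c²/e. The helper below (its name is ours) isolates that conversion; the inputs are illustrative.

```python
import numpy as np
from scipy.constants import c, e, m_e


def mean_energy_GeV(ref_beta_gamma, pt_mean):
    """Mean beam energy (GeV) and Lorentz factor from the reference beta*gamma and mean pt.

    Uses gamma = gamma_ref - pt * (beta*gamma)_ref, as in the visualization script.
    """
    ref_gamma = np.sqrt(1.0 + ref_beta_gamma**2)
    beam_gamma = ref_gamma - pt_mean * ref_beta_gamma
    return beam_gamma * m_e * c**2 / e * 1e-9, beam_gamma


E0, g0 = mean_energy_GeV(1957.0, 0.0)  # reference beam, about 1 GeV
E1, g1 = mean_energy_GeV(1957.0, -0.01)  # pt < 0: above the reference energy
print(f"E0 = {E0:.4f} GeV, E1 = {E1:.4f} GeV")
```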