"""Define :class:`Particle`, created by reading CST ParticleMonitor files."""
import logging
import math
from typing import Any
import numpy as np
import vedo
from numpy.typing import NDArray
from simultipac.constants import clight, qelem
from simultipac.particle_monitor.vector import Momentum, Position
from simultipac.plotter.plotter import Plotter
PartMonLine = tuple[str, str, str, str, str, str, str, str, str, str, str, str]
PartMonData = tuple[
float,
float,
float,
float,
float,
float,
float,
float,
float,
float,
int,
int,
]
[docs]
class Particle: # pylint: disable=too-many-instance-attributes
"""Holds evolution of position and adim momentum with time.
Position in :unit:`mm`, time in :unit:`ns`.
Attributes
----------
position :
Position of the particle during the simulation.
momentum :
Momentum of the particle during the simulation.
_masses :
Mass of particle at each time step. An error is raised if it changes
between two files.
mass :
Mass of the particle in :unit:`kg`.
mass_eV :
Mass of the particle in :unit:`eV`.
_charges :
Charge of particle at each time step. An error is raised if it changes
between two files.
charge :
Charge of the particle.
time :
Holds the time steps in :unit:`ns` corresponding to every value of
``pos``, ``mom``, etc.
particle_id :
Unique id for the particle.
source_id :
Gives information on how the particle was created.
extrapolated_times :
Times at which position and momentum are extrapolated.
"""
[docs]
def __init__(self, raw_line: PartMonLine) -> None:
"""Init from a line of a position_monitor file."""
self.extrapolated_times: np.ndarray | None = None
self._masses: list[float]
self.mass: float
self.mass_eV: float # pylint: disable=invalid-name
self._charges: list[float]
self.charge: float
self._macro_charge: list[float]
self._time: list[float]
self.time: NDArray[np.float64]
self.particle_id: int
self.source_id: int
_line = _str_to_correct_types(raw_line)
self.position = Position((_line[0],), (_line[1],), (_line[2],))
self.momentum = Momentum((_line[3],), (_line[4],), (_line[5],))
self._masses = [_line[6]]
self._charges = [_line[7]]
self._macro_charge = [_line[8]]
self._time = [_line[9]]
self.particle_id = _line[10]
self.source_id = _line[11]
self.alive_at_end = False
self.emission_cell_id: np.ndarray = np.array([], dtype=np.float64)
self.emission_point: np.ndarray = np.array([], dtype=np.uint32)
self.emission_angle: float = np.nan
self.collision_cell_id: np.ndarray = np.array([], dtype=np.uint32)
self.collision_point: np.ndarray = np.array([], dtype=np.float64)
self.collision_angle: float = np.nan
[docs]
def add_a_file(self, raw_line: PartMonLine) -> None:
"""Add a time-step/a file to the current Particle."""
line = _str_to_correct_types(raw_line)
self.position.append(line[0:3])
self.momentum.append(line[3:6])
self._masses.append(line[6])
self._charges.append(line[7])
self._macro_charge.append(line[8])
self._time.append(line[9])
[docs]
def finalize(self) -> None:
"""Post treat object for consistency checks, better data types."""
self._check_constanteness_of_some_attributes()
self.time = np.array(self._time)
self._switch_to_mm_ns_units()
if not _is_sorted(self.time):
self._sort_by_increasing_time_values()
[docs]
def _check_constanteness_of_some_attributes(self) -> None:
"""Ensure that mass and charge did not evolve during simulation."""
self.mass = _get_constant(self._masses)
self.mass_eV = self.mass * clight**2 / qelem
self.charge = _get_constant(self._charges)
@property
def macro_charge(self) -> NDArray[np.float64]:
"""Return macro charge as an array."""
return np.array(self._macro_charge)
[docs]
def _switch_to_mm_ns_units(self) -> None:
"""Change the system units to limit rounding errors.
.. warning::
In CST Particle Monitor files, the time is given in seconds *
1e-18 (aka nano-nanoseconds). Tested with CST units for time in
nanoseconds.
"""
self.position.normalize()
self.time *= 1e18
[docs]
def _sort_by_increasing_time_values(self) -> None:
"""Sort arrays by increasing time values."""
idx = np.argsort(self.time)
self.position.reorder(idx)
self.momentum.reorder(idx)
self._macro_charge = [self._macro_charge[i] for i in idx]
self.time = self.time[idx]
@property
def emission_energy(self) -> float:
"""Compute emission energy in :unit:`eV`."""
return self.momentum.emission_energy(self.mass_eV)
@property
def collision_energy(self) -> float:
"""Determine the impact energy in :unit:`eV`.
Returns
-------
energy: float
The last known energy in :unit:`eV`.
"""
return self.momentum.collision_energy(self.mass_eV)
[docs]
def extrapolate_pos_and_mom_one_time_step_further(self) -> None:
"""Extrapolate position and momentum by one time step.
CST PIC solves the motion with a leapfrog solver (source: Mohamad
Houssini from Keonys, private communication).
Several possibilities:
- ``pos`` corresponds to ``time`` and ``mom`` shifted by half
time-steps (most probable).
- ``mom`` corresponds to ``time`` and ``pos`` shifted by half
time-steps (also possible).
- ``pos`` or ``mom`` is interpolated so that both are expressed at
full ``time`` steps (what I will consider for now).
"""
n_extrapolated_points = 2
n_extrapolated_time_steps = 10
self.extrapolated_times = np.full(n_extrapolated_points, np.nan)
if self.time.shape[0] <= 1:
return
fit_end = self.time[-1]
time_step = self.time[-1] - self.time[-2]
extrapolated_time_end = fit_end + n_extrapolated_time_steps * time_step
self.extrapolated_times = np.linspace(
fit_end, extrapolated_time_end, n_extrapolated_points
)
delta_t = self.extrapolated_times - fit_end
self.position.extrapolate(self.momentum, delta_t)
n_time_steps_for_polynom_fitting = 3
poly_fit_deg = 2
if poly_fit_deg >= n_time_steps_for_polynom_fitting:
raise OSError(
f"You need at least {poly_fit_deg + 1} momentum and "
"time step(s) to extrapolate momentum with a degree "
f"{poly_fit_deg} polynom."
)
if n_time_steps_for_polynom_fitting > self.time.shape[0]:
return
self.momentum.extrapolate(
self.time,
self.extrapolated_times,
poly_fit_deg,
n_time_steps_for_polynom_fitting,
)
[docs]
def determine_if_alive_at_end(
self, max_time: float, tol: float = 1e-6
) -> None:
"""Determine if the particle collisioned before end of simulation.
This method sets :attr:`.alive_at_end` flag.
Parameters
----------
max_time :
Simulation end time in :unit:`ns`.
tol :
Tolerance in :unit:`ns`.
"""
if abs(max_time - self.time[-1]) < tol:
self.alive_at_end = True
[docs]
def find_collision(
self,
mesh: vedo.Mesh,
warn_no_collision: bool = True,
warn_multiple_collisions: bool = False,
**kwargs,
) -> None:
"""Find where the trajectory impacts the structure.
If the particle is alive at the end of the simulation, we do not even
try. If it has only one known time step, neither do we.
We first try to detect a collision between the last known position of
the particle and the last extrapolated position. If no collision is
found, we try to find it between the last known position and the
know position just before that.
.. note::
If the last extrapolated position is too far from the last known
position, several collisions may be detected.
.. todo::
Take only nearest cell instead of the one with the lowest ID as for
now.
Parameters
----------
mesh :
``vedo`` mesh object describing the structure of the rf system.
warn_no_collision :
If True, a warning is raised when the electron was not alive at the
end of the simulation, but no collision was detected. The default
is True.
warn_multiple_collisions :
To warn if several collisions were detected for the same particle.
Also remove all collisions but the first one. The default is True.
kwargs :
kwargs
"""
if self.alive_at_end:
return
if self.position.n_steps <= 1:
return
p_0 = self.position.last
assert self.position.extrapolated is not None
p_1 = self.position.extrapolated[-1]
collision_point, collision_cell = mesh.intersect_with_line(
p0=p_0, p1=p_1, return_ids=True, tol=0
)
if collision_point.shape[0] == 0:
if self.position.n_steps <= 2:
return
p_1 = p_0
p_0 = self.position.array[-2, :]
collision_point, collision_cell = mesh.intersect_with_line(
p0=p_0, p1=p_1, return_ids=True, tol=0
)
if warn_no_collision and collision_point.shape[0] == 0:
logging.info(f"No collision for particle {self.particle_id}.")
return
if collision_point.shape[0] > 1:
collision_point = collision_point[0, :]
collision_cell = collision_cell[0, np.newaxis]
if warn_multiple_collisions:
logging.warning(
"More than one collision for particle "
f"{self.particle_id}. Only considering the first."
)
self.collision_cell_id = collision_cell
self.collision_point = collision_point
return
[docs]
def compute_emission_angle(self, mesh: vedo.Mesh) -> None:
"""Compute the angle of emission."""
raise NotImplementedError
if self.collision_cell_id.shape[0] < 1:
return
direction = self.momentum.first
normal = mesh.cell_normals[self.emission_cell_id]
adjacent = normal.dot(direction)
opposite = np.linalg.norm(np.cross(normal, direction))
tan_theta = opposite / adjacent
self.emission_angle = abs(math.atan(tan_theta))
[docs]
def compute_collision_angle(self, mesh: vedo.Mesh) -> None:
"""Compute the angle of impact."""
if self.alive_at_end:
return
if self.collision_cell_id.shape[0] < 1:
return
direction = self.momentum.last
normal = mesh.cell_normals[self.collision_cell_id]
adjacent = normal.dot(direction)
opposite = np.linalg.norm(np.cross(normal, direction))
tan_theta = opposite / adjacent
self.collision_angle = abs(math.atan(tan_theta))
[docs]
def plot_trajectory(
self,
plotter: Plotter,
emission_color: str | None = None,
collision_color: str | None = None,
lw: int = 7,
r: int = 8,
**kwargs,
) -> Any:
"""Plot the trajectory of the particle in 3D.
Parameters
----------
plotter :
Objet realizing the plots.
emission_color :
If provided, the first known position is colored with this color.
collision_color :
If provided, the last known position is colored with this color.
collision_point :
If provided and ``collision_color`` is not ``None``, we plot this
point instead of the last of ``points``. This is useful when the
extrapolated time is large, and actuel collision point may differ
significantly from last position points.
lw :
Trajectory line width.
r :
Size of the emission/collision points.
"""
collision_point = self.collision_point
if collision_point.shape == (1, 3):
collision_point = collision_point[0]
return plotter.plot_trajectory(
points=self.position.to_list,
emission_color=emission_color if self.source_id != 0 else None,
collision_color=collision_color if not self.alive_at_end else None,
collision_point=collision_point,
lw=lw,
r=r,
**kwargs,
)
[docs]
def _str_to_correct_types(line: PartMonLine) -> PartMonData:
"""Convert the input line of strings to proper data types."""
corrected = (
float(line[0]),
float(line[1]),
float(line[2]),
float(line[3]),
float(line[4]),
float(line[5]),
float(line[6]),
float(line[7]),
float(line[8]),
float(line[9]),
int(line[10]),
int(line[11]),
)
return corrected
[docs]
def _get_constant(variables: list[float]) -> float:
"""Check that the list of floats is a constant, return constant."""
asarray = np.array(variables)
if not (asarray == asarray[0]).all():
raise ValueError
return asarray[0]
[docs]
def _is_sorted(array: np.ndarray) -> bool:
"""Check that given array is ordered (increasing values)."""
return (array == np.sort(array)).all()