Source code for simultipac.plotter.default

"""Define a default plotter."""

import logging
from pathlib import Path
from typing import Any, Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import vedo
import vedo.backends
from matplotlib.axes import Axes
from matplotlib.typing import ColorType
from numpy.typing import NDArray

from simultipac.constants import markdown
from simultipac.plotter.plotter import Plotter
from simultipac.types import PARTICLE_0D_t, PARTICLE_3D_t

VEDO_BACKENDS_t = Literal["k3d", "vtk", "2d"]


[docs] class DefaultPlotter(Plotter): """An object using maptlotlib for 2D, Vedo for 3D."""
[docs] def __init__( self, vedo_backend: VEDO_BACKENDS_t = "2d", *args, **kwargs ) -> None: """Set basic settings for the 3D Vedo plotter. Parameters ---------- vedo_backend : The backend used by ``vedo``. The options that I tested were: - ``"k3d"``: Needs additional libraries (``pip install -e .[k3d]``). A little bugged, see :meth:`._k3d_patch`. - ``"vtk"``: Interactive 3D plots. - ``"2d"``: Non-interactive 2D plots. """ self._vedo_backend: VEDO_BACKENDS_t self.vedo_backend = vedo_backend self._plotter_3d = vedo.Plotter() self._show_3d = False return super().__init__(*args, **kwargs)
@property def vedo_backend(self) -> VEDO_BACKENDS_t: """The name of the vedo backend; *a priori*, no need to access that.""" return self._vedo_backend @vedo_backend.setter def vedo_backend(self, value: VEDO_BACKENDS_t) -> None: """Update the vedo backend.""" vedo.settings.default_backend = value self._vedo_backend = value if value in ("k3d",): self._k3d_patch()
[docs] def plot( self, data: pd.DataFrame, x: str, y: str, grid: bool = True, axes: Axes | None = None, xlabel: str | None = None, ylabel: str | None = None, label: str | None = None, **kwargs, ) -> tuple[Axes | NDArray[Any], ColorType]: """Plot 2D data. Parameters ---------- data : Holds all data to plot. x, y : Name of column in ``data`` for x/y. grid : If grid should be plotted. Default is True. axes : Axes to re-use, if provided. The default is None (plot on new axis). xlabel, ylabel : Name of the labels. If not provided, we use the markdown equivalent of x/y, if defined in :data:`.markdown`. label : If provided, overrides the legend. Useful when several simulations are shown on the same plot. kwargs : Other keyword passed to the ``pd.DataFrame.plot`` method. Returns ------- axes : Axes | NDArray[Any] Objects created by the ``pd.DataFrame.plot`` method. color : ColorType Color used for the plot. """ if xlabel is None: xlabel = markdown.get(x, x) if ylabel is None: ylabel = markdown.get(y, y) axes = data.plot( x=x, y=y, grid=grid, ax=axes, xlabel=xlabel, ylabel=ylabel, label=label, **kwargs, ) assert axes is not None color = self._get_color_from_last_plot(axes) return axes, color
[docs] def _get_color_from_last_plot( self, axes: Axes | NDArray[Any] ) -> ColorType: """Get the color used for the last plot.""" ax = axes if isinstance(axes, Axes) else axes[-1] assert isinstance(ax, Axes) lines = ax.get_lines() color = lines[-1].get_color() return color
[docs] def hist( self, data: pd.DataFrame, x: PARTICLE_0D_t, bins: int = 200, hist_range: tuple[float, float] | None = None, xlabel: str | None = None, title: str | None = None, **kwargs, ) -> Any: if xlabel is None: xlabel = markdown.get(x, x) axes = data.plot( kind="hist", bins=bins, range=hist_range, xlabel=xlabel, title=title, **kwargs, ) assert axes is not None return axes
[docs] def plot_3d( self, data: Any, key: PARTICLE_3D_t, *args, **kwargs, ) -> Any: self._show_3d = True raise NotImplementedError
[docs] def plot_mesh(self, mesh: vedo.Mesh, *args, **kwargs) -> vedo.Plotter: """Plot the mesh (``STL`` file).""" self._show_3d = True self._plotter_3d += mesh return self._plotter_3d
[docs] def plot_trajectory( self, points: list[NDArray[np.float64]], emission_color: str | None = None, collision_color: str | None = None, collision_point: NDArray[np.float64] = np.array([], dtype=np.float64), lw: int = 7, r: int = 2, **kwargs, ) -> vedo.Plotter: """Plot the :class:`.Particle` trajectory stored in ``points``. Parameters ---------- points : List of positions, as returned by :meth:`.Vector.to_list`. emission_color : If provided, the first known position is colored with this color. collision_color : If provided, the last known position is colored with this color. collision_point : If provided and ``collision_color`` is not ``None``, we plot this point instead of the last of ``points``. This is useful when the extrapolated time is large, and actuel collision point may differ significantly from last position points. lw : Trajectory line width. r : Size of the emission/collision points. """ self._show_3d = True objects = vedo.Lines(points[:-1], points[1:], lw=lw, **kwargs) if emission_color is not None: objects += vedo.Point(pos=points[0], r=r, c=emission_color) if collision_color is not None: if len(collision_point) == 0: collision_point = points[-1] objects += vedo.Point(pos=collision_point, r=r, c=collision_color) self._plotter_3d += objects return self._plotter_3d
[docs] def load_mesh( self, stl_path: str | Path, stl_alpha: float | None = None, **kwargs ) -> vedo.Mesh: mesh = vedo.load(stl_path) if stl_alpha is not None: mesh.alpha(stl_alpha) return mesh
[docs] def show(self) -> None: """Show the plots that were produced. Useful for the bash interface. """ plt.show() if not self._show_3d: return _plotter_3d: vedo.Plotter = self._plotter_3d _plotter_3d.show()
[docs] def _k3d_patch(self) -> None: """Patch ``point_size`` to avoid following error. .. code-block:: File "/home/placais/Documents/simulation/python/simultipac/examples/./analyze_cst_particle_monitor.py", line 61, in <module> result.show() ~~~~~~~~~~~^^ File "/home/placais/Documents/simulation/python/simultipac/src/simultipac/simulation_results/simulation_results.py", line 324, in show return self._plotter.show() ~~~~~~~~~~~~~~~~~~^^ File "/home/placais/Documents/simulation/python/simultipac/src/simultipac/plotter/default.py", line 235, in show _plotter_3d.show() ~~~~~~~~~~~~~~~~^^ File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/vedo/plotter.py", line 3337, in show return backends.get_notebook_backend(self.objects) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^ File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/vedo/backends.py", line 31, in get_notebook_backend return start_k3d(actors2show) File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/vedo/backends.py", line 349, in start_k3d kobj = k3d.points( ia.coordinates.astype(np.float32), ...<5 lines>... name=name, ) File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/k3d/factory.py", line 620, in points Points( ~~~~~~^ positions=positions, ^^^^^^^^^^^^^^^^^^^^ ...<15 lines>... compression_level=compression_level, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ), ^ File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/k3d/objects.py", line 735, in __init__ super(Points, self).__init__(**kwargs) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^ File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/k3d/objects.py", line 194, in __init__ super(DrawableWithCallback, self).__init__(**kwargs) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^ File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/k3d/objects.py", line 108, in __init__ super(Drawable, self).__init__(**kwargs) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^ File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/ipywidgets/widgets/widget.py", line 478, in __init__ super(Widget, self).__init__(**kwargs) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^ File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/traitlets/traitlets.py", line 1355, in __init__ setattr(self, key, value) ~~~~~~~^^^^^^^^^^^^^^^^^^ File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/traitlets/traitlets.py", line 716, in __set__ self.set(obj, value) ~~~~~~~~^^^^^^^^^^^^ File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/traitlets/traitlets.py", line 690, in set new_value = self._validate(obj, value) File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/traitlets/traitlets.py", line 722, in _validate value = self.validate(obj, value) File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/traitlets/traitlets.py", line 2460, in validate self.error(obj, value) ~~~~~~~~~~^^^^^^^^^^^^ File "/home/placais/.pyenv/versions/simultipac/lib/python3.13/site-packages/traitlets/traitlets.py", line 831, in error raise TraitError(e) traitlets.traitlets.TraitError: The 'point_size' trait of a Points instance expected a float or a dict, not the float64 np.float64(0.0). This method overrides the default ``k3d.objects.Points`` constructor. May be related to `this issue <https://github.com/marcomusy/vedo/issues/1197>`_. This quick patch seems to raise other errors... So for now, prefer ``"vtk"``. """ logging.info("Applying patch for k3d.") import k3d original_k3d_points = k3d.points def patched_k3d_points(*args, **kwargs) -> k3d.objects.Points: """Instantiate ``k3d`` points with proper ``point_size`` arg.""" ps = kwargs.get("point_size") if ps is None: return original_k3d_points(*args, **kwargs) if ps <= 0.0: logging.info("patching invalid point_size=0.0 -> 1.0.") ps = 1.0 kwargs["point_size"] = ps return original_k3d_points(*args, **kwargs) k3d.points = patched_k3d_points