"""Visualization utilities."""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from eeg_positions.config import RADIUS_INNER_CONTOUR
from eeg_positions.utils import _get_coords_on_circle
def _plot_spherical_head():
    """Create a 3D figure showing a unit sphere as a head model.

    The axes are labelled ``x``/``y``/``z`` and limited to ``(-1, 1)``, the
    grid and pane backgrounds are hidden, three black lines through the
    origin stand in for the coordinate axes, and a translucent unit sphere
    is drawn both as a wireframe and as a surface.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The Figure object.
    ax : matplotlib.axes.Axes
        The Axes object.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    # Axis labels, aspect, and identical limits on every dimension
    label_size = 20
    ax.set_xlabel("x", fontsize=label_size)
    ax.set_ylabel("y", fontsize=label_size)
    ax.set_zlabel("z", fontsize=label_size)
    ax.set_aspect("auto")
    for set_limits in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_limits((-1, 1))

    # Hide the grid and blank out the three background panes
    ax.grid(False)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor("w")

    # Draw one straight black line per dimension through the origin. The
    # lines span six times the largest axis limit in both directions.
    limits = [ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()]
    reach = np.max(np.abs(limits)) * 6
    spine = np.linspace(-reach, reach, 11)
    zeros = np.zeros_like(spine)
    for axis_index in range(3):
        xyz = [zeros, zeros, zeros]
        xyz[axis_index] = spine
        ax.plot(*xyz, color="k")

    # Unit sphere sampled on a 100 x 100 grid of (azimuth, polar) angles
    n_samples = 100j
    azimuth, polar = np.mgrid[0 : 2 * np.pi : n_samples, 0 : np.pi : n_samples]
    sphere = (
        np.cos(azimuth) * np.sin(polar),
        np.sin(azimuth) * np.sin(polar),
        np.cos(polar),
    )
    ax.plot_wireframe(*sphere, color="k", linestyle=":", alpha=0.1)
    ax.plot_surface(*sphere, color="k", alpha=0.1)

    ax.set_box_aspect((1, 1, 1))
    return fig, ax
def _plot_2d_head(radius_inner_contour=None, show_axis=False):
    """Draw a 2D head outline (circle plus nose) viewed from above.

    Parameters
    ----------
    radius_inner_contour : int | float | None
        Radius of an optional inner circle that visualizes an inner contour
        line. With the default of None, no inner circle is drawn. A
        convenient value is ``eeg_positions.config.RADIUS_INNER_CONTOUR``,
        which is the Fpz-T8-Oz-T7 contour line.
    show_axis : bool
        If True, keep the x- and y-axes of the coordinate system visible.
        Defaults to False, which turns the axes off.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The Figure object.
    ax : matplotlib.axes.Axes
        The Axes object.
    """
    head_radius = 1.0
    linewidth = 1.0

    fig, ax = plt.subplots()
    ax.set_aspect("equal")
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    # Outer head outline, plus the optional thinner inner contour
    circle_style = dict(color="k", fill=False)
    outline = plt.Circle((0, 0), head_radius, linewidth=linewidth, **circle_style)
    ax.add_artist(outline)
    if radius_inner_contour is not None:
        inner_contour = plt.Circle(
            (0, 0), radius_inner_contour, linewidth=linewidth / 2, **circle_style
        )
        ax.add_artist(inner_contour)

    # Nose: two lines from points on the head circle to a tip above the head
    nose_width = 5
    nose_tip = 1.1
    nose_base_l, nose_base_r = (
        _get_coords_on_circle(r=head_radius, steps=nose_width)[idx]
        for idx in (-1, 1)
    )
    for base in (nose_base_l, nose_base_r):
        ax.plot((base[0], 0), (base[1], nose_tip), "k", linewidth=linewidth)

    # Dotted vertical and horizontal guide lines through the origin
    guide_style = dict(color="black", linewidth=linewidth / 2, linestyles="dotted")
    ax.vlines(x=0, ymin=-1, ymax=1, **guide_style)
    ax.hlines(y=0, xmin=-1, xmax=1, **guide_style)

    # Leave a 10 % margin around the head so the nose tip stays in view
    margin = head_radius * 1.1
    ax.set_xlim([-margin, margin])
    ax.set_ylim([-margin, margin])
    fig.set_layout_engine("constrained")

    if not show_axis:
        ax.set_axis_off()
    return fig, ax
[docs]
def plot_coords(coords, scatter_kwargs=None, text_kwargs=None):
    """Plot standard EEG electrode coordinates.

    A 2D head plot is produced when `coords` has no ``"z"`` column, and a
    3D spherical head plot otherwise. Every electrode is drawn as a scatter
    point and annotated with its label.

    Parameters
    ----------
    coords : pandas.DataFrame
        The standard EEG electrode coordinates as computed on a sphere.
        A pandas DataFrame object with the columns ``"label"``, ``"x"``,
        ``"y"``, and optionally ``"z"``.
    scatter_kwargs : dict | None
        Optional keyword arguments to be passed to the
        :meth:`matplotlib.axes.Axes.scatter`
        or its 3D variant, depending on the dimensions of `coords`.
        In the 2D case ``zorder`` defaults to 2.5 unless given here.
        Defaults to None, passing no extra arguments.
    text_kwargs : dict | None
        Optional keyword arguments to be passed to the
        :meth:`matplotlib.axes.Axes.text`. ``fontsize`` defaults to 6
        unless given here. Defaults to None, passing no extra arguments.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The Figure object.
    ax : matplotlib.axes.Axes
        The Axes object.

    Raises
    ------
    ValueError
        If `coords` is not a pandas DataFrame, or if it lacks one of the
        required columns ``"label"``, ``"x"``, ``"y"``.
    """
    # Input check
    if not isinstance(coords, pd.DataFrame):
        raise ValueError("`coords` must be a pandas DataFrame object.")
    for colname in ("label", "x", "y"):
        if colname not in coords.columns:
            raise ValueError(f"`coords` does not have a required column {colname}.")

    # Merge user settings into fresh dicts so caller-owned dicts are never
    # mutated and no state is shared between calls.
    scatter_settings = {} if scatter_kwargs is None else dict(scatter_kwargs)
    text_settings = {"fontsize": 6}
    if text_kwargs is not None:
        text_settings.update(text_kwargs)

    if "z" not in coords.columns:
        # 2D: zorder is a default rather than a hard-coded argument, so a
        # user-supplied ``zorder`` overrides it instead of raising a
        # duplicate-keyword TypeError.
        scatter_settings = {"zorder": 2.5, **scatter_settings}
        fig, ax = _plot_2d_head(RADIUS_INNER_CONTOUR)
        ax.scatter(coords["x"], coords["y"], **scatter_settings)
        for label, x, y in zip(coords["label"], coords["x"], coords["y"]):
            ax.text(x, y, label, **text_settings)
        return fig, ax

    # 3D: deliberately one scatter call per electrode, as before, so the
    # rendered output (including per-call default colors) is unchanged.
    fig, ax = _plot_spherical_head()
    columns = (coords["label"], coords["x"], coords["y"], coords["z"])
    for label, x, y, z in zip(*columns):
        ax.scatter3D(x, y, z, **scatter_settings)
        ax.text(x, y, z, label, **text_settings)
    return fig, ax