import matplotlib.animation as animation
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle
from numpy.typing import NDArray
class ComplexMatrixAnimation:
fig: plt.Figure
ax: plt.Axes
cax: plt.Axes
def __init__(
self,
data: NDArray[np.complex128],
time: NDArray | None = None,
title: str = "Complex Matrix Hinton Plot",
row_names: list[str] | None = None,
col_names: list[str] | None = None,
time_unit: str = "fs",
cmap: str = "hsv",
figshape: tuple[int, int] = (14, 10),
add_text: bool = False,
) -> None:
self.data = data
if time is None:
time = np.arange(data.shape[0], dtype=np.float64)
self.time = time
self.title = title
self.row_names = row_names
self.col_names = col_names
self.time_unit = time_unit
self.figshape = figshape
self.cmap = cmap
self._validate_input()
self.norm = np.abs(self.data).real
self.maxnorm = self.norm.max()
phase = np.angle(self.data) # -pi to pi
self.phase = (phase + 2 * np.pi) % (2 * np.pi) # 0 to 2pi
self.add_text = add_text
@property
def rows(self) -> int:
return self.data.shape[1]
@property
def cols(self) -> int:
return self.data.shape[2]
def _validate_input(self) -> None:
"""Validate input data for complex matrix animation.
Raises:
ValueError: If data is not complex128 or not a square matrix
"""
if (
not isinstance(self.data, np.ndarray)
or self.data.dtype != np.complex128
):
raise ValueError("Input must be a complex128 numpy array")
if not isinstance(self.time, np.ndarray):
raise ValueError("Time must be a numpy array")
if len(self.data.shape) != 3:
raise ValueError("Input must have shape (time, row, column)")
if len(self.time.shape) != 1:
raise ValueError("Time must be a 1D array")
if self.data.shape[0] != self.time.shape[0]:
raise ValueError(
"Time steps must match the first dimension of the data"
)
if (
self.row_names is not None
and len(self.row_names) != self.data.shape[1]
):
raise ValueError(
"Number of row names must match the number of rows in the matrix"
)
if (
self.col_names is not None
and len(self.col_names) != self.data.shape[2]
):
raise ValueError(
"Number of column names must match the number of columns in the matrix"
)
_, rows, cols = self.data.shape
if rows != cols:
raise ValueError(
f"Each frame must be a square matrix, got shape ({rows}, {cols})"
)
[docs]
def set_ax(self, title: str | None = None) -> None:
ax = self.ax
cols = self.cols
rows = self.rows
if title is None:
title = self.title
assert isinstance(title, str)
row_names = self.row_names
col_names = self.col_names
ax.set_title(title, fontsize=24)
ax.set_xlim(-1, cols)
ax.set_ylim(-1, rows)
ax.grid(True)
ax.set_aspect("equal", adjustable="box")
ax.set_xticks(np.arange(cols))
ax.set_yticks(np.arange(rows))
ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
ax.invert_yaxis()
if row_names:
# set fontsize=14
ax.set_yticklabels(row_names, fontsize=16)
if col_names:
ax.set_xticklabels(col_names, fontsize=16)
[docs]
def plot_each_element(
self,
i: int,
j: int,
cmap: plt.Colormap,
norm: np.ndarray,
phase: np.ndarray,
data: np.ndarray,
) -> None:
"""Plot each element of the complex matrix.
Args:
i: Row index
j: Column index
cmap: Colormap object
norm: Magnitude of the complex number
phase: Phase of the complex number
data: Complex matrix data
"""
magnitude = norm[i, j]
phase_value = phase[i, j]
value = data[i, j]
if magnitude > 0:
# Size based on normalized magnitude
# size = (magnitude / self.maxnorm) * 0.95
size = np.sqrt(2.0 * (magnitude / self.maxnorm)) * 0.95
# Color based on phase (normalize from [-π, π] to [0, 1])
color = cmap(phase_value / (2 * np.pi))
# Create and add rectangle
rect = Rectangle(
(j - size / 2, i - size / 2),
size,
size,
facecolor=color,
edgecolor="gray",
)
self.ax.add_patch(rect)
# Add text annotation
if self.add_text:
text = f"{value: .2f}"
self.ax.text(
j,
i,
text,
horizontalalignment="center",
verticalalignment="center",
fontsize=8,
color="white",
bbox=dict(facecolor="black", alpha=0.7, edgecolor="none"),
)
[docs]
def update(self, frame_num: int) -> None:
"""Update function for animation.
Args:
frame_num: Frame number
"""
# Get current frame data
frame_data = self.data[frame_num]
rows, cols = frame_data.shape
time = self.time[frame_num]
title = f"{self.title} {time: .2f} {self.time_unit}"
self.ax.clear()
self.set_ax(title)
_cmap = plt.get_cmap(self.cmap)
norm = self.norm[frame_num]
phase = self.phase[frame_num]
# Plot each element
for i in range(rows):
for j in range(cols):
self.plot_each_element(i, j, _cmap, norm, phase, frame_data)
[docs]
def set_cyclic_colorbar(self) -> mcolors.Colormap:
ax = self.cax
theta = np.linspace(0.0, 2 * np.pi, 100)
r = np.linspace(0, 1, 100)
Theta, R = np.meshgrid(theta, r)
cmap = plt.get_cmap(self.cmap)
norm = mcolors.Normalize(vmin=0.0, vmax=2 * np.pi)
ax.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2])
ax.set_xticklabels(
["0", r"$\frac{\pi}{2}$", r"$\pi$", r"$\frac{3\pi}{2}$"],
fontsize=14,
)
ax.set_yticks([])
ax.pcolormesh(
Theta, R, Theta, cmap=cmap, norm=norm
) # , shading="auto", alpha=R / R.max())
return cmap
[docs]
def create_animation(
self,
interval: int = 200,
) -> tuple[plt.Figure, animation.FuncAnimation]:
"""Create an animation of complex matrix Hinton plots.
Args:
data: Complex array of shape (time, row, column)
interval: Time interval between frames in milliseconds
Returns:
Figure and Animation objects
"""
self.setup_figure()
# Create animation
anim = animation.FuncAnimation(
self.fig,
self.update, # type: ignore
frames=self.data.shape[0],
fargs=(),
interval=interval,
blit=False,
)
return self.fig, anim
def save_animation(
    anim: animation.FuncAnimation,
    filename: str = "animation.gif",
    fps: int = 5,
    dpi: int = 100,
) -> None:
    """Write an animation to disk as a GIF through the Pillow writer.

    Progress messages are printed before and after the file is written.

    Args:
        anim: Animation to export
        filename: Path of the GIF file to create
        fps: Playback speed of the GIF in frames per second
        dpi: Resolution (dots per inch) passed on to ``anim.save``
    """
    print(f"Saving animation to {filename}...")
    anim.save(filename, writer=animation.PillowWriter(fps=fps), dpi=dpi)
    print("Animation saved successfully!")
def get_anim(
    data: NDArray[np.complex128],
    time: NDArray | None = None,
    title: str = "Density Matrix Evolution",
    row_names: list[str] | None = None,
    col_names: list[str] | None = None,
    time_unit: str = "fs",
    save_gif: bool = False,
    gif_filename: str = "animation.gif",
    cmap: str = "hsv",
    fps: int = 5,
    dpi: int = 100,
    add_text: bool = False,
) -> tuple[plt.Figure, animation.FuncAnimation]:
    """Main function to create Hinton plot animation from complex matrix data.

    Builds a ``ComplexMatrixAnimation`` from the arguments, creates the
    animation with its default frame interval and, if requested, writes it
    to a GIF file via ``save_animation``.

    Args:
        data (NDArray[np.complex128]): Complex array of shape (time, row, column).
        time (NDArray | None, optional): Array of time points corresponding to the data. Defaults to None.
        title (str, optional): Title of the plot. Defaults to "Density Matrix Evolution".
        row_names (list[str] | None, optional): List of row names. Defaults to None.
        col_names (list[str] | None, optional): List of column names. Defaults to None.
        time_unit (str, optional): Unit of time to display on the plot. Defaults to "fs".
        save_gif (bool, optional): Whether to save the animation as a GIF. Defaults to False.
        gif_filename (str, optional): Output filename for GIF. Defaults to "animation.gif".
        cmap (str, optional): Colormap to use for the plot. Defaults to "hsv".
            `cmap` should be cyclic such as 'twilight', 'twilight_shifted', 'hsv'.
            See also https://matplotlib.org/stable/users/explain/colors/colormaps.html#cyclic.
        fps (int, optional): Frames per second for GIF. Defaults to 5.
        dpi (int, optional): Dots per inch for the output GIF. Defaults to 100.
        add_text (bool, optional): Whether to print each matrix element's value on the plot. Defaults to False.

    Returns:
        tuple[plt.Figure, animation.FuncAnimation]: Figure and Animation objects.

    Example:
        >>> # Create a 3x3 complex matrix that evolves over 10 time steps
        >>> t = np.linspace(0, 2*np.pi, 10)
        >>> data = np.zeros((10, 3, 3), dtype=np.complex128)
        >>> for i in range(10):
        ...     data[i] = np.exp(1j * t[i]) * np.random.random((3, 3))
        >>> fig, anim = get_anim(data, time=t, save_gif=True)
        >>> plt.show()
    """
    # Keyword arguments keep this call correct even if the constructor's
    # parameter order changes.
    anim_obj = ComplexMatrixAnimation(
        data,
        time=time,
        title=title,
        row_names=row_names,
        col_names=col_names,
        time_unit=time_unit,
        cmap=cmap,
        add_text=add_text,
    )
    fig, anim = anim_obj.create_animation()
    if save_gif:
        save_animation(anim, gif_filename, fps=fps, dpi=dpi)
    return fig, anim
if __name__ == "__main__":
    # Demo: a 5x5 matrix whose elements share a common phase that advances
    # from 0 to 2*pi over 20 frames, with random per-element magnitude and
    # phase jitter.
    n_steps = 20
    dim = 5
    t = np.linspace(0, 2 * np.pi, n_steps)
    test_data = np.zeros((n_steps, dim, dim), dtype=np.complex128)
    for step, base_angle in enumerate(t):
        # Magnitudes drawn from [0.5, 1.5)
        amplitude = np.random.random((dim, dim)) + 0.5
        # Common rotation plus up to pi/4 of per-element noise
        angle = base_angle + np.random.random((dim, dim)) * np.pi / 4
        test_data[step] = amplitude * np.exp(1j * angle)

    # Ket labels for the rows, bra labels for the columns.
    ket_labels = ["$|" + f"{k}" + r"\rangle$" for k in range(dim)]
    bra_labels = [r"$\langle" + f"{k}" + r"|$" for k in range(dim)]

    # Build the animation and write it out as a GIF
    print("Creating animation...")
    fig, anim = get_anim(
        test_data,
        time=t,
        title="Density Matrix Evolution",
        save_gif=True,
        gif_filename="complex_matrix.gif",
        cmap="twilight_shifted",
        time_unit="fs",
        row_names=ket_labels,
        col_names=bra_labels,
        fps=5,
        dpi=100,
    )
    plt.show()