import matplotlib.animation as animation
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle
from numpy.typing import NDArray
class ComplexMatrixAnimation:
    """Animated Hinton plot of a time series of square complex matrices.

    Each matrix element is drawn as a square whose size grows with the
    element's magnitude and whose color encodes its phase through a cyclic
    colormap. A separate axes (``cax``) shows the phase-to-color legend.

    Args:
        data: complex128 array of shape ``(time, row, column)``; every frame
            must be square.
        time: 1D array of time points, one per frame. Defaults to
            ``0, 1, ..., n_frames - 1``.
        title: Base title; the current time and ``time_unit`` are appended
            to it in every frame.
        row_names: Optional tick labels for the rows.
        col_names: Optional tick labels for the columns.
        time_unit: Unit label shown after the time in the title.
        cmap: Name of a matplotlib colormap; it should be cyclic
            (e.g. ``"hsv"``, ``"twilight"``).
        figshape: Figure size in inches, used as ``figsize``.
        add_text: If True, print each element's value on top of its square.

    Raises:
        ValueError: If the inputs fail validation (see ``_validate_input``).
    """

    fig: plt.Figure  # created by setup_figure()
    ax: plt.Axes  # main Hinton-plot axes
    cax: plt.Axes  # axes holding the cyclic phase legend

    def __init__(
        self,
        data: NDArray[np.complex128],
        time: NDArray | None = None,
        title: str = "Complex Matrix Hinton Plot",
        row_names: list[str] | None = None,
        col_names: list[str] | None = None,
        time_unit: str = "fs",
        cmap: str = "hsv",
        figshape: tuple[int, int] = (14, 10),
        add_text: bool = False,
    ) -> None:
        self.data = data
        # Default time axis 0, 1, ..., n_frames - 1. Only derive it when data
        # is an array with at least one axis; otherwise let _validate_input
        # report the bad data instead of failing here with AttributeError.
        if time is None and isinstance(data, np.ndarray) and data.ndim > 0:
            time = np.arange(data.shape[0], dtype=np.float64)
        self.time = time
        self.title = title
        self.row_names = row_names
        self.col_names = col_names
        self.time_unit = time_unit
        self.figshape = figshape
        self.cmap = cmap
        self.add_text = add_text
        self._validate_input()
        # Element magnitudes and their global maximum. Square sizes are
        # scaled by this maximum consistently across all frames.
        self.norm = np.abs(self.data)
        self.maxnorm = self.norm.max()
        # np.angle returns values in [-pi, pi]; shift them to [0, 2pi) so that
        # phase / (2*pi) maps directly onto the colormap's [0, 1) domain.
        phase = np.angle(self.data)
        self.phase = (phase + 2 * np.pi) % (2 * np.pi)

    @property
    def rows(self) -> int:
        """Number of rows in each frame."""
        return self.data.shape[1]

    @property
    def cols(self) -> int:
        """Number of columns in each frame."""
        return self.data.shape[2]

    def _validate_input(self) -> None:
        """Validate ``self.data``, ``self.time`` and the optional tick labels.

        Raises:
            ValueError: If ``data`` is not a complex128 numpy array of shape
                ``(time, row, column)``, is empty, or has non-square frames;
                if ``time`` is not a 1D numpy array with one entry per frame;
                or if ``row_names``/``col_names`` have the wrong length.
        """
        data = self.data
        if not isinstance(data, np.ndarray) or data.dtype != np.complex128:
            raise ValueError("Input must be a complex128 numpy array")
        # Check the data's rank before the time array, so that a wrongly
        # shaped input is reported as such (time may still be None here).
        if data.ndim != 3:
            raise ValueError("Input must have shape (time, row, column)")
        if data.size == 0:
            # Without this check, an empty array fails later in norm.max()
            # with an unhelpful "zero-size array" message.
            raise ValueError("Input data must not be empty")
        if not isinstance(self.time, np.ndarray):
            raise ValueError("Time must be a numpy array")
        if self.time.ndim != 1:
            raise ValueError("Time must be a 1D array")
        n_frames, rows, cols = data.shape
        if n_frames != self.time.shape[0]:
            raise ValueError(
                "Time steps must match the first dimension of the data"
            )
        if self.row_names is not None and len(self.row_names) != rows:
            raise ValueError(
                "Number of row names must match the number of rows in the matrix"
            )
        if self.col_names is not None and len(self.col_names) != cols:
            raise ValueError(
                "Number of column names must match the number of columns in the matrix"
            )
        if rows != cols:
            raise ValueError(
                f"Each frame must be a square matrix, got shape ({rows}, {cols})"
            )

    def setup_figure(self) -> None:
        """Create ``fig``, the main Hinton axes ``ax`` and the legend axes ``cax``.

        ``create_animation`` calls this method before building the animation.
        """
        self.fig = plt.figure(figsize=self.figshape)
        # A wide column for the matrix and a narrow column for the phase legend.
        grid = self.fig.add_gridspec(1, 2, width_ratios=(5, 1))
        self.ax = self.fig.add_subplot(grid[0, 0])
        # set_cyclic_colorbar plots a (theta, r) mesh colored by theta, which
        # renders as a color wheel on polar axes.
        self.cax = self.fig.add_subplot(grid[0, 1], projection="polar")
        self.set_ax()
        self.set_cyclic_colorbar()

    def set_ax(self, title: str | None = None) -> None:
        """(Re)configure the main axes: title, limits, grid and tick labels.

        Args:
            title: Title to display. Defaults to ``self.title``.
        """
        ax = self.ax
        rows, cols = self.rows, self.cols
        if title is None:
            title = self.title
        assert isinstance(title, str)
        ax.set_title(title, fontsize=24)
        # One unit of padding around the element centers at 0 .. n-1.
        ax.set_xlim(-1, cols)
        ax.set_ylim(-1, rows)
        ax.grid(True)
        ax.set_aspect("equal", adjustable="box")
        ax.set_xticks(np.arange(cols))
        ax.set_yticks(np.arange(rows))
        # Column labels on top and row 0 at the top, as in a printed matrix.
        ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)
        ax.invert_yaxis()
        if self.row_names:
            ax.set_yticklabels(self.row_names, fontsize=16)
        if self.col_names:
            ax.set_xticklabels(self.col_names, fontsize=16)

    def plot_each_element(
        self,
        i: int,
        j: int,
        cmap: mcolors.Colormap,
        norm: np.ndarray,
        phase: np.ndarray,
        data: np.ndarray,
    ) -> None:
        """Draw element ``(i, j)`` of one frame as a colored square.

        Elements whose magnitude is not strictly positive are skipped.

        Args:
            i: Row index (drawn at y = i).
            j: Column index (drawn at x = j).
            cmap: Colormap object used to color the square by phase.
            norm: Magnitudes of the frame's elements.
            phase: Phases of the frame's elements, in ``[0, 2*pi)``.
            data: Complex values of the frame (used only for the text label).
        """
        magnitude = norm[i, j]
        if not magnitude > 0:
            return
        # The side length grows with sqrt(magnitude), so the square's area is
        # proportional to the magnitude relative to the global maximum.
        # NOTE(review): with the sqrt(2) * 0.95 factor, elements above ~55% of
        # maxnorm exceed the unit cell and overlap their neighbors.
        size = np.sqrt(2.0 * (magnitude / self.maxnorm)) * 0.95
        # The phase is already in [0, 2*pi); map it onto the colormap's [0, 1).
        color = cmap(phase[i, j] / (2 * np.pi))
        self.ax.add_patch(
            Rectangle(
                (j - size / 2, i - size / 2),
                size,
                size,
                facecolor=color,
                edgecolor="gray",
            )
        )
        if self.add_text:
            self.ax.text(
                j,
                i,
                f"{data[i, j]: .2f}",
                horizontalalignment="center",
                verticalalignment="center",
                fontsize=8,
                color="white",
                bbox=dict(facecolor="black", alpha=0.7, edgecolor="none"),
            )

    def update(self, frame_num: int) -> None:
        """Redraw the main axes for one frame (the FuncAnimation callback).

        Args:
            frame_num: Index into the first axis of ``data`` and ``time``.
        """
        frame_data = self.data[frame_num]
        rows, cols = frame_data.shape
        time = self.time[frame_num]
        title = f"{self.title} {time: .2f} {self.time_unit}"
        # clear() removes the previous frame's patches and texts, but it also
        # resets the axes configuration; set_ax restores it.
        self.ax.clear()
        self.set_ax(title)
        colormap = plt.get_cmap(self.cmap)
        norm = self.norm[frame_num]
        phase = self.phase[frame_num]
        for i in range(rows):
            for j in range(cols):
                self.plot_each_element(i, j, colormap, norm, phase, frame_data)

    def set_cyclic_colorbar(self) -> mcolors.Colormap:
        """Draw the phase legend on ``self.cax`` and return the colormap used.

        A (theta, r) mesh with theta in ``[0, 2*pi]`` and r in ``[0, 1]`` is
        colored by theta with ``pcolormesh``. The x ticks are labelled 0,
        pi/2, pi and 3*pi/2, and the y ticks are hidden.
        """
        ax = self.cax
        theta = np.linspace(0.0, 2 * np.pi, 100)
        radius = np.linspace(0, 1, 100)
        theta_grid, radius_grid = np.meshgrid(theta, radius)
        cmap = plt.get_cmap(self.cmap)
        # Same [0, 2*pi) -> [0, 1) mapping that plot_each_element uses.
        phase_norm = mcolors.Normalize(vmin=0.0, vmax=2 * np.pi)
        ax.set_xticks([0, np.pi / 2, np.pi, 3 * np.pi / 2])
        ax.set_xticklabels(
            ["0", r"$\frac{\pi}{2}$", r"$\pi$", r"$\frac{3\pi}{2}$"],
            fontsize=14,
        )
        ax.set_yticks([])
        ax.pcolormesh(theta_grid, radius_grid, theta_grid, cmap=cmap, norm=phase_norm)
        return cmap

    def create_animation(
        self,
        interval: int = 200,
    ) -> tuple[plt.Figure, animation.FuncAnimation]:
        """Build the figure and a FuncAnimation over all frames of ``data``.

        Args:
            interval: Delay between frames in milliseconds.

        Returns:
            The figure and the animation object.
        """
        self.setup_figure()
        anim = animation.FuncAnimation(
            self.fig,
            self.update,  # type: ignore
            frames=self.data.shape[0],
            interval=interval,
            blit=False,
        )
        return self.fig, anim
 
def save_animation(
    anim: animation.FuncAnimation,
    filename: str = "animation.gif",
    fps: int = 5,
    dpi: int = 100,
) -> None:
    """Render every frame of *anim* and write it to *filename* with Pillow.

    Progress messages are printed before and after saving.

    Args:
        anim: Animation object, e.g. the one returned by ``get_anim``.
        filename: Output path. Defaults to ``"animation.gif"``.
        fps: Frames per second of the written file.
        dpi: Resolution (dots per inch) used to render each frame.
    """
    print(f"Saving animation to {filename}...")
    writer = animation.PillowWriter(fps=fps)
    anim.save(filename, writer=writer, dpi=dpi)
    print("Animation saved successfully!")
def get_anim(
    data: NDArray[np.complex128],
    time: NDArray | None = None,
    title: str = "Density Matrix Evolution",
    row_names: list[str] | None = None,
    col_names: list[str] | None = None,
    time_unit: str = "fs",
    save_gif: bool = False,
    gif_filename: str = "animation.gif",
    cmap: str = "hsv",
    fps: int = 5,
    dpi: int = 100,
    add_text: bool = False,
    figshape: tuple[int, int] = (14, 10),
    interval: int = 200,
) -> tuple[plt.Figure, animation.FuncAnimation]:
    """Create a Hinton-plot animation from a time series of complex matrices.

    Args:
        data (NDArray[np.complex128]): Complex array of shape (time, row, column);
            each frame must be square.
        time (NDArray | None, optional): 1D array of time points, one per frame.
            Defaults to None, which means 0, 1, ..., n_frames - 1.
        title (str, optional): Title prefix; the current time is appended to it.
            Defaults to "Density Matrix Evolution".
        row_names (list[str] | None, optional): List of row names. Defaults to None.
        col_names (list[str] | None, optional): List of column names. Defaults to None.
        time_unit (str, optional): Unit of time to display in the title. Defaults to "fs".
        save_gif (bool, optional): Whether to save the animation as a GIF. Defaults to False.
        gif_filename (str, optional): Output filename for the GIF. Defaults to "animation.gif".
        cmap (str, optional): Colormap to use for the plot. Defaults to "hsv".
            `cmap` should be cyclic, such as 'twilight', 'twilight_shifted' or 'hsv'.
            See also https://matplotlib.org/stable/users/explain/colors/colormaps.html#cyclic.
        fps (int, optional): Frames per second of the saved GIF. Defaults to 5.
        dpi (int, optional): Dots per inch for the saved GIF. Defaults to 100.
        add_text (bool, optional): Whether to print each matrix element's value
            on its square. Defaults to False.
        figshape (tuple[int, int], optional): Figure size in inches. Defaults to (14, 10).
        interval (int, optional): Delay between frames in milliseconds, passed to
            the animation. The saved GIF's frame rate is set by `fps`. Defaults to 200.

    Returns:
        tuple[plt.Figure, animation.FuncAnimation]: Figure and Animation objects.

    Example:
        >>> # Create a 3x3 complex matrix that evolves over 10 time steps
        >>> t = np.linspace(0, 2*np.pi, 10)
        >>> data = np.zeros((10, 3, 3), dtype=np.complex128)
        >>> for i in range(10):
        ...     data[i] = np.exp(1j * t[i]) * np.random.random((3, 3))
        >>> fig, anim = get_anim(data, time=t, save_gif=True)
        >>> plt.show()
    """
    anim_obj = ComplexMatrixAnimation(
        data,
        time,
        title,
        row_names,
        col_names,
        time_unit,
        cmap=cmap,
        figshape=figshape,
        add_text=add_text,
    )
    fig, anim = anim_obj.create_animation(interval=interval)
    if save_gif:
        save_animation(anim, gif_filename, fps, dpi)
    return fig, anim
if __name__ == "__main__":
    # Demo: a 5x5 matrix of complex numbers whose phases make one full turn
    # over the time axis, with random magnitudes and a little phase jitter.
    n_frames = 20
    dim = 5
    t = np.linspace(0, 2 * np.pi, n_frames)
    test_data = np.zeros((n_frames, dim, dim), dtype=np.complex128)
    for frame, t_now in enumerate(t):
        # Draw the magnitudes first, then the jitter (same RNG call order).
        amplitude = np.random.random((dim, dim)) + 0.5  # values in [0.5, 1.5)
        jitter = np.random.random((dim, dim)) * np.pi / 4  # values in [0, pi/4)
        test_data[frame] = amplitude * np.exp(1j * (t_now + jitter))
    print("Creating animation...")
    ket_labels = [rf"$|{k}\rangle$" for k in range(dim)]
    bra_labels = [rf"$\langle{k}|$" for k in range(dim)]
    fig, anim = get_anim(
        test_data,
        time=t,
        title="Density Matrix Evolution",
        save_gif=True,
        gif_filename="complex_matrix.gif",
        cmap="twilight_shifted",
        time_unit="fs",
        row_names=ket_labels,
        col_names=bra_labels,
        fps=5,
        dpi=100,
    )
    plt.show()