Source code for viz

"""
viz.py

A module that defines the Viz class for creating and manipulating
matplotlib-based visualizations with a consistent interface. Includes
plotting methods (e.g., bar, line, scatter), layout utilities, and
tools for combining multiple plots into a single figure.
"""

from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import numpy as np


class _VizCore:
    def save(self, path, **kwargs):
        """
        Saves the figure to a file.

        Parameters
        ----------
        path : str
            The file path to save the figure.
        kwargs : dict, optional
            Additional keyword arguments passed to `fig.savefig()`, such as:
            - dpi : int, optional (dots per inch for image resolution)
            - bbox_inches : str or 'tight', optional (to adjust bounding box)
            - transparent : bool, optional (if True, the background is transparent)
        """
        self.fig.savefig(path, **kwargs)

    def show(self, clear=False):
        """
        Displays the plot.

        Parameters
        ----------
        clear : bool, optional
            If True, the previous output is cleared before showing the plot.
        """
        if clear:
            clear_output(wait=True)
        try:
            display(self.fig)
        except NameError:
            plt.show()

    def clear(self):
        """
        Clears the current axis.

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.cla()
        return self

    def twinx(self):
        """
        Creates a twin axis sharing the same x-axis but different y-axis.

        Returns
        -------
        Viz
            A new Viz object with the twin axis.
        """
        twin_ax = self.ax.twinx()
        return Viz(twin_ax, self.fig)

    def imshow(self, *args, **kwargs):
        """
        Displays an image on the plot.

        Parameters
        ----------
        args : tuple
            The image data (e.g., a 2D array representing pixel intensities).
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.imshow()`, such as:
            - cmap : str, optional (colormap for displaying the image)
            - interpolation : str, optional (method for interpolation, e.g., 'nearest', 'bilinear')
            - alpha : float, optional (transparency, from 0 to 1)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.imshow(*args, **kwargs)
        return self

    @staticmethod
    def combine_viz(viz_list, nrows=None, ncols=None):
        """
        Combines multiple Viz objects into a single figure with subplots.

        Parameters
        ----------
        viz_list : list
            A list of Viz objects to combine.
        nrows : int, optional
            The number of rows for the subplot grid (default is None).
        ncols : int, optional
            The number of columns for the subplot grid (default is None).

        Returns
        -------
        Viz
            A new Viz object containing the combined plots.
        """
        if nrows is None or ncols is None:
            # If grid size is not provided, calculate it based on the length of viz_list
            total_plots = len(viz_list)
            ncols = int(np.ceil(np.sqrt(total_plots)))  # Approx square grid
            nrows = int(np.ceil(total_plots / ncols))  # Make sure all plots fit

        # Create a new figure and subplot grid
        fig, axs = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 5))

        # Flatten the axes for easy iteration (if it's a multi-dimensional grid)
        axs = axs.flatten() if nrows * ncols > 1 else [axs]

        # Iterate over the list of Viz objects and corresponding axes
        for i, viz in enumerate(viz_list):
            # Get the current axis to plot on
            ax = axs[i]

            # Transfer the plot to the new axis (copying the properties)
            # viz_copy = Viz(ax, fig)  # Create a new Viz instance with the current subplot axis
            # Copy the plot and other settings from the original Viz object to the new axis
            ax.set_title(viz.ax.get_title())  # Copy the title
            ax.set_xlabel(viz.ax.get_xlabel())  # Copy the xlabel
            ax.set_ylabel(viz.ax.get_ylabel())  # Copy the ylabel
            # Copy grid visibility and style
            xgridlines = viz.ax.get_xgridlines()
            ygridlines = viz.ax.get_ygridlines()
            gridlines = xgridlines + ygridlines

            if any(line.get_visible() for line in gridlines):
                gridline = next(
                    (line for line in gridlines if line.get_visible()), None
                )
                if gridline:
                    ax.grid(
                        True,
                        linestyle=gridline.get_linestyle(),
                        color=gridline.get_color(),
                        linewidth=gridline.get_linewidth(),
                    )
            else:
                ax.grid(False)
            # Copy other properties like lines, scatter, etc., based on what the viz object has
            for line in viz.ax.lines:
                ax.plot(line.get_xdata(), line.get_ydata(), label=line.get_label())
            for scatter in viz.ax.collections:
                ax.scatter(
                    scatter.get_offsets()[:, 0],
                    scatter.get_offsets()[:, 1],
                    label=scatter.get_label(),
                )

        # Adjust layout to avoid overlap
        # fig.tight_layout()
        plt.close(fig)
        # Return a new Viz instance wrapping the combined figure
        return Viz(axs[0], fig)

    def close(self):
        """
        Closes the figure.

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        plt.close(self.fig)
        return self




class _LayoutMixin:
    def set_title(self, txt, **kwargs):
        """
        Sets the title of the plot.

        Parameters
        ----------
        txt : str
            The title text.
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.set_title()`, such as:
            - fontsize : int or float, optional
            - fontweight : {'normal', 'bold', 'heavy', 'light', 'ultrabold', 'ultralight'}, optional
            - color : str, optional (e.g., 'red', 'blue', etc.)
            - pad : float, optional (distance from the top of the axes)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.set_title(txt, **kwargs)
        return self

    def xlabel(self, txt, **kwargs):
        """
        Sets the label for the x-axis.

        Parameters
        ----------
        txt : str
            The label text for the x-axis.
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.set_xlabel()`, such as:
            - fontsize : int or float, optional
            - fontweight : {'normal', 'bold', 'heavy', 'light', 'ultrabold', 'ultralight'}, optional
            - color : str, optional (e.g., 'red', 'blue', etc.)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.set_xlabel(txt, **kwargs)
        return self

    def ylabel(self, txt, **kwargs):
        """
        Sets the label for the y-axis.

        Parameters
        ---------
        txt : str
            The label text for the y-axis.
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.set_ylabel()`, such as:
            - fontsize : int or float, optional
            - fontweight : {'normal', 'bold', 'heavy', 'light', 'ultrabold', 'ultralight'}, optional
            - color : str, optional (e.g., 'red', 'blue', etc.)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.set_ylabel(txt, **kwargs)
        return self

    def legend(self, **kwargs):
        """
        Adds a legend to the plot.

        Parameters
        ----------
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.legend()`, such as:
            - loc : str or int, optional (location of the legend, e.g., 'best', 'upper left', 0)
            - fontsize : int or float, optional
            - title : str, optional (title of the legend)
            - shadow : bool, optional (whether to add shadow)
            - bbox_to_anchor : tuple, optional (to specify a custom position for the legend)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.legend(**kwargs)
        return self

    def grid(self, flag=True, **kwargs):
        """
        Enables or disables the grid on the plot.

        Parameters
        ----------
        flag : bool, optional, default True
            If True, the grid is enabled, otherwise it is disabled.
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.grid()`, such as:
            - color : str, optional (e.g., 'gray', 'blue', etc.)
            - linestyle : {'-', '--', '-.', ':'}, optional
            - linewidth : float, optional (line thickness)
            - which : {'major', 'minor'}, optional (gridlines for major or minor ticks)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.grid(flag, **kwargs)
        return self

    def tight_layout(self, **kwargs):
        """
        Adjusts the layout to prevent overlap of plot elements.

        Parameters
        ----------
        kwargs : dict, optional
            Additional keyword arguments passed to `fig.tight_layout()`, such as:
            - pad : float, optional (padding between plot elements)
            - h_pad : float, optional (height padding)
            - w_pad : float, optional (width padding)
            - rect : tuple, optional (the area to which the layout is confined,
            e.g., (left, bottom, right, top))

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.fig.tight_layout(**kwargs)
        return self

    def suptitle(self, txt, **kwargs):
        """
        Sets the title for the entire figure.

        Parameters
        ----------
        txt : str
            The title text.
        kwargs : dict, optional
            Additional keyword arguments passed to `fig.suptitle()`, such as:
            - fontsize : int or float, optional (size of the title text)
            - fontweight : {'normal', 'bold', 'heavy', 'light', 'ultrabold', 'ultralight'}, optional
            - color : str, optional (e.g., 'red', 'blue', etc.)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.fig.suptitle(txt, **kwargs)
        return self

    def set_xticks(self, ticks, **kwargs):
        """
        Sets the ticks on the x-axis.

        Parameters
        ----------
        ticks : list
            A list of positions where ticks should appear on the x-axis.
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.set_xticks()`, such as:
            - minor : bool, optional (if True, the minor ticks are set instead of the major ones)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.set_xticks(ticks, **kwargs)
        return self

    def set_yticks(self, ticks, **kwargs):
        """
        Sets the ticks on the y-axis.

        Parameters
        ----------
        ticks : list
            A list of positions where ticks should appear on the y-axis.
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.set_yticks()`, such as:
            - minor : bool, optional (if True, the minor ticks are set instead of the major ones)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.set_yticks(ticks, **kwargs)
        return self

    def invert_x(self):
        """
        Inverts the x-axis.

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.invert_xaxis()
        return self

    def invert_y(self):
        """
        Inverts the y-axis.

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.invert_yaxis()
        return self

    def set_xlim(self, *args, **kwargs):
        """
        Sets the limits for the x-axis.

        Parameters
        ----------
        args : tuple
            The limits to set as (min, max).
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.set_xlim()`, such as:
            - xmin : float, optional (minimum limit for x-axis)
            - xmax : float, optional (maximum limit for x-axis)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.set_xlim(*args, **kwargs)
        return self

    def set_ylim(self, *args, **kwargs):
        """
        Sets the limits for the y-axis.

        Parameters
        ----------
        args : tuple
            The limits to set as (min, max).
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.set_ylim()`, such as:
            - ymin : float, optional (minimum limit for y-axis)
            - ymax : float, optional (maximum limit for y-axis)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.set_ylim(*args, **kwargs)
        return self

    def annotate(self, *args, **kwargs):
        """
        Adds an annotation to the plot.

        Parameters
        ----------
        args : tuple
            The annotation arguments, typically (text, xy).
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.annotate()`, such as:
            - xytext : tuple, optional (position of annotation text)
            - arrowprops : dict, optional (properties of the arrow, e.g., {'arrowstyle': '->'})
            - fontsize : int, optional (font size of the annotation text)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.annotate(*args, **kwargs)
        return self

    def style(self, style_name="seaborn-v0_8-whitegrid"):
        """
        Applies a matplotlib style to the plot.

        Parameters
        ----------
        style_name : str, optional, default='seaborn-v0_8-whitegrid'
            The style to apply. For example, 'seaborn-darkgrid', 'ggplot', etc.

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        plt.style.use(style_name)
        return self

    def figsize(self, size):
        """
        Sets the figure size.

        Parameters
        ----------
        size : tuple
            The size of the figure as (width, height) in inches.

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.fig.set_size_inches(*size, forward=True)
        return self

    def aspect(self, value="auto"):
        """
        Sets the aspect ratio of the plot.

        Parameters
        ----------
        value : str or float, optional, default='auto'
            The aspect ratio to set:
            - 'auto' (default): automatic aspect ratio based on the data
            - 'equal': equal scaling on both axes
            - float: fixed aspect ratio, e.g., 1.0 for equal scaling
            - 'scaled': scaled based on the data range

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.set_aspect(value)
        return self


class _PlotMixin:

    def plot(self, *args, **kwargs):
        """
        Plots data on the axis.

        Parameters
        ----------
        args : tuple
            The data to plot. The first element is typically the x-data,
            and the second is the y-data.
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.plot()`, such as:
            - label : str, optional (label for the plot, used in legend)
            - linestyle : {'-', '--', '-.', ':'}, optional
            - color : str, optional (e.g., 'red', 'blue', etc.)
            - marker : {'o', 'x', 's', '^', etc.}, optional (marker style)
            - linewidth : float, optional (line thickness)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.plot(*args, **kwargs)
        return self

    def scatter(self, *args, **kwargs):
        """
        Creates a scatter plot.

        Parameters
        ----------
        args : tuple
            The data to plot as scatter. The first element is typically
            the x-data, and the second is the y-data.
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.scatter()`, such as:
            - color : str, optional (e.g., 'red', 'blue', etc.)
            - marker : {'o', 'x', 's', '^', etc.}, optional (marker style)
            - s : scalar or array-like, optional (size of markers)
            - alpha : float, optional (transparency, from 0 to 1)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.scatter(*args, **kwargs)
        return self

    def bar(self, *args, **kwargs):
        """
        Creates a bar plot.

        Parameters
        ----------
        args : tuple
            The data to plot as bars. The first element is the x-data (positions),
            and the second is the y-data (height).
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.bar()`, such as:
            - color : str, optional (e.g., 'red', 'blue', etc.)
            - width : float, optional (width of bars)
            - align : {'center', 'edge'}, optional (alignment of bars)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.bar(*args, **kwargs)
        return self

    def hlines(self, *args, **kwargs):
        """
        Draws horizontal lines across the plot.

        Parameters
        ----------
        args : tuple
            Arguments passed to `ax.hlines()`, typically:
            - y : scalar or array-like (y positions of the lines)
            - xmin : scalar, optional (left limit for the line)
            - xmax : scalar, optional (right limit for the line)
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.hlines()`, such as:
            - color : str, optional (line color)
            - linewidth : float, optional (thickness of the line)
            - linestyle : {'-', '--', '-.', ':'}, optional (line style)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.hlines(*args, **kwargs)
        return self

    def vlines(self, *args, **kwargs):
        """
        Draws vertical lines across the plot.

        Parameters
        ----------
        args : tuple
            Arguments passed to `ax.vlines()`, typically:
            - x : scalar or array-like (x positions of the lines)
            - ymin : scalar, optional (bottom limit for the line)
            - ymax : scalar, optional (top limit for the line)
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.vlines()`, such as:
            - color : str, optional (line color)
            - linewidth : float, optional (thickness of the line)
            - linestyle : {'-', '--', '-.', ':'}, optional (line style)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.vlines(*args, **kwargs)
        return self

    def contour(self, *args, **kwargs):
        """
        Creates a contour plot.

        Parameters
        ----------
        args : tuple
            The contour data, typically (X, Y, Z) where Z represents the contour levels.
        kwargs : dict, optional
            Additional keyword arguments passed to `ax.contour()`, such as:
            - levels : int or array-like, optional (specific contour levels)
            - colors : str or array-like, optional (colors for the contours)
            - linewidths : float, optional (width of contour lines)

        Returns
        -------
        self : Viz
            The Viz object itself, allowing for method chaining.
        """
        self.ax.contour(*args, **kwargs)
        return self

    def add_subplot(self, *args, **kwargs):
        """
        Adds a new subplot to the figure.

        Parameters
        ----------
        args : tuple
            Arguments for `fig.add_subplot()`, such as (nrows, ncols, index).
        kwargs : dict, optional
            Additional keyword arguments passed to `fig.add_subplot()`.

        Returns
        -------
        Viz
            A new Viz object wrapping the new subplot.
        """
        ax = self.fig.add_subplot(*args, **kwargs)
        return Viz(ax, self.fig)


[docs] class Viz(_PlotMixin, _LayoutMixin, _VizCore): """ Viz class for plotting on a matplotlib axis. Parameters ---------- ax : matplotlib.axes.Axes The axis on which the plot will be drawn. fig : matplotlib.figure.Figure, optional The figure containing the axis. Defaults to None, in which case ax.figure is used. Methods ------- add_subplot(*args, **kwargs) Adds a new subplot to the figure. """ def __init__(self, ax=None, fig=None): """ Initializes the Viz object with a given axis and optional figure. Parameters ---------- ax : matplotlib.axes.Axes The axis on which the plot will be drawn. fig : matplotlib.figure.Figure, optional The figure containing the axis (default is None, which means it uses ax.figure). """ if ax is None: fig, ax = plt.subplots() self.ax = ax self.fig = fig or ax.figure def __getattr__(self, attr): """ Retrieves attributes of the underlying axis. Parameters ---------- attr : str The name of the attribute to retrieve. Returns ------- method : function The method of the underlying axis for the given attribute. """ if hasattr(self.ax, attr): def method(*args, **kwargs): result = getattr(self.ax, attr)(*args, **kwargs) return self if result is None else result return method raise AttributeError(f"'PlotWrapper' has no attribute '{attr}'") def __dir__(self): """ Returns a list of the attributes and methods available for the Viz object. """ return sorted(set(super().__dir__()) | set(dir(self.ax))) def __enter__(self): """ Initializes the Viz object for use in a context manager (e.g., with `with` statement). """ return self def __exit__(self, exc_type, exc_val, exc_tb): """ Displays the plot when exiting a context manager. """ self.show() def __getitem__(self, key): """ Retrieves the item from the axis using the provided key. Parameters ---------- key : index or key The key or index for the item. Returns ------- item : object The item from the axis corresponding to the key. """ return self.ax[key]