Source code for figrecipe._graph._core

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2026-01-10 (ywatanabe)"
"""Core graph visualization module for FigRecipe.

Provides publication-quality graph/network visualizations compatible with
scitex styling (40mm width, 6pt fonts) and interactive HTML export.
"""

from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
from matplotlib.axes import Axes

# Maps each public layout name to the name of the networkx function that
# implements it (always "<name>_layout"). "hierarchical" is not listed here;
# _get_layout handles it separately.
LAYOUTS = {
    name: f"{name}_layout"
    for name in (
        "spring",
        "circular",
        "kamada_kawai",
        "shell",
        "spectral",
        "random",
        "planar",
        "spiral",
    )
}


def _get_layout(
    G, layout: str, pos: Optional[Dict] = None, seed: int = 42, **layout_kwargs
):
    """Compute node positions using specified layout algorithm.

    Parameters
    ----------
    G : networkx.Graph
        The graph to layout. It is not modified.
    layout : str
        Layout algorithm name: 'spring', 'circular', 'kamada_kawai', 'shell',
        'spectral', 'random', 'planar', 'spiral', or 'hierarchical'.
        Names not listed fall back to 'spring'.
    pos : dict, optional
        Pre-computed positions {node: (x, y)}. If provided, returned as-is
        without importing networkx.
    seed : int
        Random seed for reproducibility. Applied to 'spring' (including every
        spring fallback) and 'random' unless ``layout_kwargs`` already
        contains ``seed``.
    **layout_kwargs
        Additional kwargs passed to the layout function.

    Returns
    -------
    dict
        Node positions {node: (x, y)}.

    Notes
    -----
    'hierarchical' assigns each node of a DAG to its topological generation:
    sources go in layer 0, and every other node goes one layer after its
    deepest predecessor. The layers are then placed with
    ``networkx.multipartite_layout``. Non-DAGs use a spring layout instead.

    If the requested layout function raises (e.g. 'planar' on a non-planar
    graph), a RuntimeWarning is emitted and a seeded spring layout is
    returned.
    """
    if pos is not None:
        return pos

    import warnings

    import networkx as nx

    if layout == "hierarchical":
        if nx.is_directed_acyclic_graph(G):
            try:
                # Carry the layer index on a node-only helper graph so the
                # caller's graph never gains a "subset" attribute. Such an
                # attribute would otherwise be copied by graph_to_record.
                # Grouping by generation (not one layer per node) keeps
                # independent nodes side by side.
                layered = nx.Graph()
                for depth, generation in enumerate(nx.topological_generations(G)):
                    layered.add_nodes_from(generation, subset=depth)
                return nx.multipartite_layout(layered, **layout_kwargs)
            except nx.NetworkXException:
                pass
        # Fallback to spring for non-DAGs. setdefault avoids a duplicate
        # "seed" keyword when the caller supplied one in layout_kwargs.
        layout_kwargs.setdefault("seed", seed)
        return nx.spring_layout(G, **layout_kwargs)

    layout_func_name = LAYOUTS.get(layout, "spring_layout")
    layout_func = getattr(nx, layout_func_name)

    # Decide on seeding from the resolved function, not the requested name.
    # That way an unknown name that falls back to spring is reproducible too.
    if layout_func_name in ("spring_layout", "random_layout"):
        layout_kwargs.setdefault("seed", seed)

    try:
        return layout_func(G, **layout_kwargs)
    except Exception as err:  # deliberate best-effort fallback, but not silent
        warnings.warn(
            f"Layout {layout!r} failed ({err}); falling back to spring layout.",
            RuntimeWarning,
            stacklevel=2,
        )
        return nx.spring_layout(G, seed=seed)


def _resolve_node_attr(G, attr: Union[str, Callable, Any], default: Any = None) -> List:
    """Expand a node styling spec into one value per node.

    Parameters
    ----------
    G : networkx.Graph
        The graph whose nodes are styled.
    attr : None, callable, str, list/tuple/ndarray, or scalar
        - None: every node gets ``default``.
        - callable: called as ``attr(node, node_data)`` for each node.
        - str: name of a node attribute; nodes lacking it get ``default``.
        - list/tuple/ndarray: pre-computed values, returned as a list
          unchanged (used when replaying recipes; length is not checked).
        - anything else: one value shared by every node.
    default : any
        Fallback value for ``None`` specs and missing attributes.

    Returns
    -------
    list
        Values in ``G.nodes()`` order (except for the pass-through case).
    """
    node_ids = list(G.nodes())

    # Specs that do not depend on per-node data are handled up front.
    if attr is not None and not callable(attr) and not isinstance(attr, str):
        if isinstance(attr, (list, tuple, np.ndarray)):
            return list(attr)
        return [attr] * len(node_ids)

    def value_for(node):
        if attr is None:
            return default
        if callable(attr):
            return attr(node, G.nodes[node])
        return G.nodes[node].get(attr, default)

    return [value_for(node) for node in node_ids]


def _resolve_edge_attr(G, attr: Union[str, Callable, Any], default: Any = None) -> List:
    """Expand an edge styling spec into one value per edge.

    Parameters
    ----------
    G : networkx.Graph
        The graph whose edges are styled.
    attr : None, callable, str, list/tuple/ndarray, or scalar
        - None: every edge gets ``default``.
        - callable: called as ``attr(u, v, edge_data)`` for each edge.
        - str: name of an edge attribute; edges lacking it get ``default``.
        - list/tuple/ndarray: pre-computed values, returned as a list
          unchanged (used when replaying recipes; length is not checked).
        - anything else: one value shared by every edge.
    default : any
        Fallback value for ``None`` specs and missing attributes.

    Returns
    -------
    list
        Values in ``G.edges()`` order (except for the pass-through case).
    """
    edge_pairs = list(G.edges())

    # Specs that do not depend on per-edge data are handled up front.
    if attr is not None and not callable(attr) and not isinstance(attr, str):
        if isinstance(attr, (list, tuple, np.ndarray)):
            return list(attr)
        return [attr] * len(edge_pairs)

    def value_for(u, v):
        if attr is None:
            return default
        if callable(attr):
            return attr(u, v, G.edges[u, v])
        return G.edges[u, v].get(attr, default)

    return [value_for(u, v) for u, v in edge_pairs]


def _validate_graph(G):
    """Reject graphs that cannot be drawn or turned into a recipe record.

    Parameters
    ----------
    G : networkx graph
        Graph to check. It is only inspected, never modified.

    Raises
    ------
    TypeError
        If ``G`` is a MultiGraph/MultiDiGraph (parallel edges are not
        supported), or if any node ID is not a ``str``, ``int`` or ``float``.
    """
    import networkx as nx

    multi_types = (nx.MultiGraph, nx.MultiDiGraph)
    if isinstance(G, multi_types):
        raise TypeError(
            "MultiGraph and MultiDiGraph are not currently supported. "
            "Use Graph or DiGraph instead, or convert with: "
            "nx.Graph(G) or nx.DiGraph(G)"
        )

    # Only primitive IDs survive serialization of the recipe record.
    allowed_id_types = (str, int, float)
    for node in G.nodes():
        if isinstance(node, allowed_id_types):
            continue
        raise TypeError(
            f"Node ID {node!r} (type: {type(node).__name__}) is not serializable. "
            "Node IDs must be str, int, or float for recipe serialization. "
            "Consider converting node IDs to strings."
        )


def _normalize_sizes(sizes: List, min_size: float = 20, max_size: float = 300) -> List:
    """Linearly rescale values into the range [min_size, max_size].

    Parameters
    ----------
    sizes : list
        Raw numeric values. ``None``, NaN and +/-inf are treated as missing.
    min_size, max_size : float
        Output range. The smallest finite value maps to ``min_size`` and the
        largest finite value maps to ``max_size``.

    Returns
    -------
    list of float
        One value per input, in input order:

        - Missing (non-finite) entries always get ``min_size``.
        - If all finite values are equal, they get the midpoint of the range.
        - If there is no finite value at all, every entry gets ``min_size``.
    """
    values = np.asarray(sizes, dtype=float)  # None becomes NaN here
    # Only finite values define the range; a single inf would otherwise turn
    # the whole scale into NaN.
    finite = np.isfinite(values)
    if not finite.any():
        return [min_size] * len(values)

    lo = values[finite].min()
    hi = values[finite].max()
    out = np.full(values.shape, float(min_size))
    if lo == hi:
        # No spread to scale by: use the middle of the range for real values,
        # while missing entries keep min_size like in the general case.
        out[finite] = min_size + (max_size - min_size) / 2
    else:
        out[finite] = min_size + (values[finite] - lo) / (hi - lo) * (
            max_size - min_size
        )
    return out.tolist()


def _is_literal_color(attr_dicts, color) -> bool:
    """Decide whether *color* is a literal color string rather than an attribute name.

    Parameters
    ----------
    attr_dicts : iterable of dict
        Attribute dicts of the nodes (or edges) the color applies to.
    color : any
        The ``node_color`` / ``edge_color`` argument given to ``draw_graph``.

    Returns
    -------
    bool
        True only when *color* is a string that matplotlib accepts as a
        color AND no attribute dict has a key with that name. An existing
        attribute always wins, so attribute-driven coloring keeps working.
    """
    if not isinstance(color, str):
        return False
    from matplotlib.colors import is_color_like

    if not is_color_like(color):
        return False
    return not any(color in data for data in attr_dicts)


def _numeric_color_array(colors):
    """Return *colors* as a float array for colormapping, or None.

    None is returned when the values cannot be converted to float (e.g.
    color names) or when every converted value is NaN (e.g. all missing).
    """
    try:
        values = np.array(colors, dtype=float)
    except (ValueError, TypeError):
        return None
    return None if np.isnan(values).all() else values


def draw_graph(
    ax: Axes,
    G,
    *,
    layout: str = "spring",
    pos: Optional[Dict] = None,
    seed: int = 42,
    # Node styling
    node_size: Union[str, Callable, float] = 100,
    node_color: Union[str, Callable, Any] = "#3498db",
    node_alpha: float = 0.8,
    node_shape: str = "o",
    node_edgecolors: str = "white",
    node_linewidths: float = 0.34,  # 0.12mm
    # Edge styling (0.34pt = 0.12mm for SCITEX compatibility)
    edge_width: Union[str, Callable, float] = 0.34,
    edge_color: Union[str, Callable, Any] = "gray",
    edge_alpha: float = 0.5,
    edge_style: str = "solid",
    arrows: Optional[bool] = None,
    arrowsize: float = 10,
    arrowstyle: str = "-|>",
    connectionstyle: str = "arc3,rad=0.0",
    # Labels
    labels: Union[bool, Dict, str] = False,
    font_size: float = 6,
    font_color: str = "black",
    font_weight: str = "normal",
    font_family: str = "sans-serif",
    # Colormap for node_color when numeric
    colormap: str = "viridis",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    # Layout kwargs
    **layout_kwargs,
) -> Dict[str, Any]:
    """Draw a NetworkX graph on matplotlib axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes to draw on. Its frame is turned off.
    G : networkx.Graph
        The graph to draw (Graph or DiGraph; see ``_validate_graph``).
    layout : str
        Layout algorithm: 'spring', 'circular', 'kamada_kawai', 'shell',
        'spectral', 'random', 'planar', 'spiral', 'hierarchical'.
    pos : dict, optional
        Pre-computed node positions {node: (x, y)}.
    seed : int
        Random seed for layout reproducibility.
    node_size : str, callable, or float
        Node sizes. Can be an attribute name, a callable
        (node, data) -> size, or a scalar. Attribute and callable values are
        rescaled into a readable range.
    node_color : str, callable, or any
        Node colors. Can be an attribute name, a callable, a color string
        (e.g. "red", "#ff0000"), or an array. A string is read as an
        attribute name when any node has that attribute; otherwise, if it is
        a valid matplotlib color, it is used literally. Numeric values are
        mapped through ``colormap``.
    node_alpha : float
        Node transparency.
    node_shape : str
        Node marker shape.
    node_edgecolors : str
        Color of the node outlines.
    node_linewidths : float
        Width of the node outlines.
    edge_width : str, callable, or float
        Edge widths. Can be an attribute name, a callable
        (u, v, data) -> width, or a scalar. Attribute and callable values are
        rescaled to 0.5-3.0.
    edge_color : str, callable, or any
        Edge colors. Strings follow the same rule as ``node_color``.
    edge_alpha : float
        Edge transparency.
    edge_style : str
        Edge line style.
    arrows : bool, optional
        Draw arrows for directed graphs. Auto-detected if None.
    arrowsize, arrowstyle, connectionstyle
        Arrow options, forwarded only when arrows are drawn.
    labels : bool, dict, or str
        Node labels. True for node IDs, a dict for custom labels, a str for
        an attribute name (nodes lacking it show their ID).
    font_size : float
        Label font size (default 6pt for scitex).
    font_color, font_weight, font_family : str
        Label font options.
    colormap : str
        Matplotlib colormap for numeric node colors.
    vmin, vmax : float, optional
        Colormap limits for numeric node colors.
    **layout_kwargs
        Additional kwargs passed to the layout algorithm.

    Returns
    -------
    dict
        Dictionary with 'pos', 'node_collection', 'edge_collection' and
        'label_collection'. 'edge_collection' is None when G has no edges;
        'label_collection' is None when labels are disabled.
    """
    import networkx as nx

    # Validate graph type
    _validate_graph(G)

    # Auto-detect arrows for directed graphs
    if arrows is None:
        arrows = G.is_directed()

    positions = _get_layout(G, layout, pos, seed, **layout_kwargs)

    # Node sizes: attribute/callable values are rescaled, scalars/arrays kept.
    sizes = _resolve_node_attr(G, node_size, default=100)
    if isinstance(node_size, str) or callable(node_size):
        sizes = _normalize_sizes(sizes)

    # Node colors. A plain color string such as "red" must not be looked up
    # as an attribute name: the lookup would miss on every node and silently
    # fall back to the default color, ignoring the caller's choice.
    node_attr_dicts = (data for _, data in G.nodes(data=True))
    if _is_literal_color(node_attr_dicts, node_color):
        colors = [node_color] * G.number_of_nodes()
        color_array = None
    else:
        colors = _resolve_node_attr(G, node_color, default="#3498db")
        # Numeric values are drawn through the colormap.
        color_array = _numeric_color_array(colors)

    # Edge widths
    widths = _resolve_edge_attr(G, edge_width, default=1.0)
    if isinstance(edge_width, str) or callable(edge_width):
        widths = _normalize_sizes(widths, min_size=0.5, max_size=3.0)

    # Edge colors (same literal-vs-attribute rule as node colors)
    edge_attr_dicts = (data for _, _, data in G.edges(data=True))
    if _is_literal_color(edge_attr_dicts, edge_color):
        edge_colors = [edge_color] * G.number_of_edges()
    else:
        edge_colors = _resolve_edge_attr(G, edge_color, default="gray")

    # Draw edges first (so nodes appear on top)
    edge_collection = None
    if G.number_of_edges() > 0:
        edge_kwargs = {
            "width": widths,
            "edge_color": edge_colors,
            "alpha": edge_alpha,
            "style": edge_style,
            "arrows": arrows,
            "ax": ax,
        }
        # Only add arrow-specific kwargs when arrows are enabled
        # (avoids UserWarning when using LineCollection for undirected graphs)
        if arrows:
            edge_kwargs["arrowsize"] = arrowsize
            edge_kwargs["arrowstyle"] = arrowstyle
            edge_kwargs["connectionstyle"] = connectionstyle
        edge_collection = nx.draw_networkx_edges(G, positions, **edge_kwargs)

    # Draw nodes; colormap options only apply to numeric colors.
    node_kwargs = {
        "node_size": sizes,
        "node_color": colors,
        "alpha": node_alpha,
        "node_shape": node_shape,
        "edgecolors": node_edgecolors,
        "linewidths": node_linewidths,
        "ax": ax,
    }
    if color_array is not None:
        node_kwargs.update(
            node_color=color_array, cmap=colormap, vmin=vmin, vmax=vmax
        )
    node_collection = nx.draw_networkx_nodes(G, positions, **node_kwargs)

    # Draw labels
    label_collection = None
    if labels:
        if labels is True:
            # Use node IDs as labels
            label_dict = {n: str(n) for n in G.nodes()}
        elif isinstance(labels, str):
            # Use node attribute as labels (node ID when missing)
            label_dict = {n: str(G.nodes[n].get(labels, n)) for n in G.nodes()}
        elif isinstance(labels, dict):
            label_dict = labels
        else:
            label_dict = {n: str(n) for n in G.nodes()}

        label_collection = nx.draw_networkx_labels(
            G,
            positions,
            labels=label_dict,
            font_size=font_size,
            font_color=font_color,
            font_weight=font_weight,
            font_family=font_family,
            ax=ax,
        )

    # Remove axes frame for cleaner look
    ax.axis("off")

    return {
        "pos": positions,
        "node_collection": node_collection,
        "edge_collection": edge_collection,
        "label_collection": label_collection,
    }
def graph_to_record(
    G,
    pos: Optional[Dict] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Serialize a NetworkX graph, with optional positions and style, to a dict.

    Parameters
    ----------
    G : networkx.Graph or networkx.DiGraph
        Graph to serialize; checked with ``_validate_graph`` first.
    pos : dict, optional
        ``{node: (x, y)}``. Nodes present here get float ``"x"``/``"y"``.
    **kwargs
        Drawing parameters, stored unchanged under ``"style"``.

    Returns
    -------
    dict
        ``{"type": "graph", "directed": bool, "nodes": [...],
        "edges": [...], "style": kwargs}``.

        - Each node entry is a shallow copy of the node's attributes plus
          ``"id"`` (and ``"x"``/``"y"`` when positioned).
        - Each edge entry is a shallow copy of the edge's attributes plus
          ``"source"``/``"target"``.
        - An existing attribute with one of those reserved names is
          overwritten in the entry.
    """
    _validate_graph(G)

    node_records = []
    for node, attrs in G.nodes(data=True):
        entry = dict(attrs)
        entry["id"] = node
        if pos and node in pos:
            coords = pos[node]
            entry["x"] = float(coords[0])
            entry["y"] = float(coords[1])
        node_records.append(entry)

    edge_records = []
    for source, target, attrs in G.edges(data=True):
        entry = dict(attrs)
        entry["source"] = source
        entry["target"] = target
        edge_records.append(entry)

    return {
        "type": "graph",
        "directed": G.is_directed(),
        "nodes": node_records,
        "edges": edge_records,
        "style": kwargs,
    }


def record_to_graph(record: Dict[str, Any]):
    """Rebuild a NetworkX graph from a record made by ``graph_to_record``.

    Parameters
    ----------
    record : dict
        Record with optional ``"directed"``, ``"nodes"``, ``"edges"`` and
        ``"style"`` keys.

    Returns
    -------
    tuple
        ``(G, pos, style_kwargs)``:

        - ``G`` is a DiGraph when ``record["directed"]`` is truthy,
          otherwise a Graph.
        - ``pos`` maps node -> (x, y) for nodes that had both coordinates,
          or is None when no node had them.
        - ``style_kwargs`` is a shallow copy of the stored style.

    Notes
    -----
    This function does not modify the input record. Node keys ``"id"``,
    ``"x"`` and ``"y"`` and edge keys ``"source"``/``"target"`` are always
    consumed, so they are never restored as graph attributes.
    """
    import networkx as nx

    graph_cls = nx.DiGraph if record.get("directed", False) else nx.Graph
    G = graph_cls()

    pos = {}
    for raw_node in record.get("nodes", []):
        attrs = raw_node.copy()  # work on a copy; the record stays intact
        node_id = attrs.pop("id")
        x = attrs.pop("x", None)
        y = attrs.pop("y", None)
        G.add_node(node_id, **attrs)
        if x is not None and y is not None:
            pos[node_id] = (x, y)

    for raw_edge in record.get("edges", []):
        attrs = raw_edge.copy()  # work on a copy; the record stays intact
        source = attrs.pop("source")
        target = attrs.pop("target")
        G.add_edge(source, target, **attrs)

    style = record.get("style", {}).copy()
    return G, (pos or None), style


# EOF