Source code for figrecipe._diagram._shared._split

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: 2025-01-24
# File: figrecipe/_diagram/_split.py

"""
Auto-split large diagrams into multiple figures.

Strategies:
- by_groups: Split by existing layout.groups (deterministic, paper-friendly)
- by_articulation: Split at hub nodes (graph-theoretic)

The split preserves "ghost nodes" at boundaries for visual continuity.
"""

from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple

from ._schema import DiagramSpec


class SplitStrategy(Enum):
    """Strategies available for partitioning an oversized diagram."""

    # Partition along the groups declared in ``layout.groups``.
    BY_GROUPS = "by_groups"

    # Partition around hub nodes.
    BY_ARTICULATION = "by_articulation"
@dataclass
class SplitConfig:
    """Options controlling whether and how a diagram is auto-split.

    Field order is part of the interface (positional construction), so it
    matches the generated ``__init__`` signature:
    ``(enabled, max_nodes, strategy, keep_hubs, ghost_style)``.
    """

    # Master switch; when False the diagram is returned unsplit.
    enabled: bool = False

    # A diagram is split only when it has more nodes than this.
    max_nodes: int = 12

    # How the nodes are partitioned into figures.
    strategy: SplitStrategy = SplitStrategy.BY_GROUPS

    # When True, nodes on the far side of a cut edge are repeated in the
    # figure as "ghost" nodes.
    keep_hubs: bool = True

    # Value assigned to the ``emphasis`` attribute of ghost nodes.
    ghost_style: str = "muted"
@dataclass
class SplitResult:
    """Outcome of splitting a diagram into one or more figures.

    ``figures`` and ``labels`` are parallel lists: ``labels[i]`` is the
    label of ``figures[i]``.
    """

    # One diagram specification per produced figure.
    figures: List[DiagramSpec]

    # Figure labels (A, B, C, ...); a single empty string when no split
    # was performed.
    labels: List[str]

    # Nodes that appear in multiple figures.
    cut_nodes: Set[str]
def split_diagram(
    spec: DiagramSpec,
    config: Optional[SplitConfig] = None,
    group_assignments: Optional[List[List[str]]] = None,
) -> SplitResult:
    """
    Split a diagram into multiple figures.

    Parameters
    ----------
    spec : DiagramSpec
        Original diagram specification. It is not modified; every figure
        receives copies of the nodes and edges it contains.
    config : SplitConfig, optional
        Split configuration. When omitted, ``SplitConfig(enabled=True)`` is
        used, i.e. splitting is switched on with the default limits.
    group_assignments : List[List[str]], optional
        Manual group assignments for splitting (one list of node IDs per
        figure). If provided and non-empty, overrides automatic detection.

    Returns
    -------
    SplitResult
        The split figures, their labels ("A", "B", ..., "Z", "AA", ...) and
        the IDs of all ghost nodes created at the cuts. If splitting is
        disabled or the diagram has at most ``config.max_nodes`` nodes, the
        result holds the original ``spec`` with a single empty label.
    """
    if config is None:
        config = SplitConfig(enabled=True)

    # Nothing to do when splitting is off or the diagram is small enough.
    if not config.enabled or len(spec.nodes) <= config.max_nodes:
        return SplitResult(figures=[spec], labels=[""], cut_nodes=set())

    # Decide which node IDs go into which figure.
    if group_assignments:
        groups = group_assignments
    elif config.strategy == SplitStrategy.BY_GROUPS:
        groups = _split_by_groups(spec, max_nodes=config.max_nodes)
    else:  # BY_ARTICULATION
        groups = _split_by_articulation(spec)

    figures: List[DiagramSpec] = []
    labels: List[str] = []
    cut_nodes: Set[str] = set()
    for index, group_nodes in enumerate(groups):
        figure, ghosts = _create_split_figure(spec, group_nodes, config)
        figures.append(figure)
        labels.append(_figure_label(index))
        cut_nodes.update(ghosts)

    return SplitResult(figures=figures, labels=labels, cut_nodes=cut_nodes)


def _figure_label(index: int) -> str:
    """
    Return the label for the figure at zero-based position ``index``.

    Labels run "A".."Z" and then continue spreadsheet-style ("AA", "AB",
    ...), so more than 26 figures never produce non-letter characters.
    """
    letters: List[str] = []
    remaining = index + 1  # bijective base-26 works on 1-based numbers
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, 26)
        letters.append(chr(ord("A") + offset))
    return "".join(reversed(letters))


def _split_by_groups(spec: DiagramSpec, max_nodes: int = 12) -> List[List[str]]:
    """
    Split by existing layout.groups using greedy packing.

    Groups are taken in the iteration order of ``spec.layout.groups`` and
    packed into the current figure; a new figure is started whenever adding
    the next group would push the current one past ``max_nodes``. A group is
    never divided, so a single group larger than ``max_nodes`` still ends up
    in one figure.

    Nodes that belong to no group are appended to the first figure. When no
    groups are defined at all, the node list is simply cut in half.

    Returns list of node ID lists, one per split figure.
    """
    if not spec.layout.groups:
        # No groups defined - try to split in half
        node_ids = [n.id for n in spec.nodes]
        mid = len(node_ids) // 2
        return [node_ids[:mid], node_ids[mid:]]

    # Greedy packing: add groups until max_nodes would be exceeded.
    figures: List[List[str]] = []
    current_figure: List[str] = []
    for group_nodes in spec.layout.groups.values():
        # Only start a new figure if the current one already holds something,
        # otherwise an oversized group would leave an empty figure behind.
        if current_figure and len(current_figure) + len(group_nodes) > max_nodes:
            figures.append(current_figure)
            current_figure = []
        current_figure.extend(group_nodes)

    # Don't forget the last figure
    if current_figure:
        figures.append(current_figure)

    # Add ungrouped nodes to first figure
    grouped: Set[str] = set()
    for fig in figures:
        grouped.update(fig)
    for n in spec.nodes:
        if n.id not in grouped:
            if figures:
                figures[0].append(n.id)
            else:
                figures.append([n.id])

    # Ensure at least 2 figures if the only figure is still too large.
    if len(figures) == 1 and len(figures[0]) > max_nodes:
        nodes = figures[0]
        mid = len(nodes) // 2
        figures = [nodes[:mid], nodes[mid:]]

    return figures


def _split_by_articulation(spec: DiagramSpec) -> List[List[str]]:
    """
    Split around the hub node.

    The hub is the node with the most distinct neighbours (ties go to the
    node listed first in ``spec.nodes``). Note that this is a heuristic: the
    hub is not verified to be a true articulation point. The hub is removed,
    the connected components of the remaining undirected graph are
    collected, and the hub is then appended to every component so that it
    appears in each resulting figure.

    Edges that reference a node ID missing from ``spec.nodes`` are ignored.

    Returns list of node ID lists, one per split figure. Node order inside
    each list is breadth-first from the earliest-listed node of the
    component, followed by the hub.
    """
    from collections import deque

    # Adjacency as insertion-ordered "sets" (dict keys) so traversal order,
    # and therefore node order in the output, is reproducible between runs.
    adj: Dict[str, Dict[str, None]] = {n.id: {} for n in spec.nodes}
    if not adj:
        # No nodes at all: a single empty figure instead of max() failing.
        return [[]]

    for e in spec.edges:
        # Skip dangling edges rather than failing with KeyError.
        if e.source in adj and e.target in adj:
            adj[e.source][e.target] = None
            adj[e.target][e.source] = None

    # Find node with most connections (hub)
    hub = max(adj, key=lambda node_id: len(adj[node_id]))

    # Pre-marking the hub as visited blocks traversal through it, which is
    # equivalent to removing it from the graph.
    visited: Set[str] = {hub}

    components: List[List[str]] = []
    for node in spec.nodes:
        start = node.id
        if start in visited:
            continue

        # Breadth-first search; nodes are marked when enqueued so each one
        # enters the queue exactly once.
        visited.add(start)
        component = [start]
        queue = deque([start])
        while queue:
            curr = queue.popleft()
            for neighbor in adj[curr]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    component.append(neighbor)
                    queue.append(neighbor)
        components.append(component)

    if not components:
        # The hub is the only node in the diagram.
        return [[hub]]

    # Add hub to each component so it is shown in every figure.
    for comp in components:
        comp.append(hub)

    return components


def _create_split_figure(
    spec: DiagramSpec,
    node_ids: List[str],
    config: SplitConfig,
) -> Tuple[DiagramSpec, Set[str]]:
    """
    Create a split figure containing specified nodes.

    The figure holds copies of the requested nodes and, when
    ``config.keep_hubs`` is true, a "ghost" copy of every node on the other
    end of an edge that crosses the figure boundary. Ghost copies get
    ``emphasis = config.ghost_style`` and a label prefixed with "→ ".
    Edges are kept when both endpoints are shown and at least one endpoint
    is one of the figure's own nodes, so edges that merely connect two
    ghosts are left to the figure that owns them. Layout groups and layers
    are filtered down to the nodes shown, dropping any that become empty.

    Parameters
    ----------
    spec : DiagramSpec
        Original diagram specification (left unmodified).
    node_ids : List[str]
        IDs of the nodes that belong to this figure. IDs that do not occur
        in ``spec.nodes`` produce no node.
    config : SplitConfig
        Supplies ``keep_hubs`` and ``ghost_style``.

    Returns (figure_spec, ghost_node_ids).
    """
    node_id_set = set(node_ids)
    node_map = {n.id: n for n in spec.nodes}

    # Find edges that cross the boundary and remember their outside
    # endpoint. A dict is used as an insertion-ordered set so ghost nodes
    # are added in edge order, not in hash order.
    ghost_ids: Dict[str, None] = {}
    if config.keep_hubs:
        for edge in spec.edges:
            src_in = edge.source in node_id_set
            tgt_in = edge.target in node_id_set
            if src_in == tgt_in:
                continue  # fully inside or fully outside: not a cut edge
            outside = edge.target if src_in else edge.source
            # Only real nodes can become ghosts; a dangling endpoint would
            # otherwise leave an edge pointing at a node that is not drawn.
            if outside in node_map:
                ghost_ids[outside] = None

    # Create new spec
    new_spec = DiagramSpec(
        type=spec.type,
        title=spec.title,
        paper=deepcopy(spec.paper),
        layout=deepcopy(spec.layout),
        theme=dict(spec.theme),
    )

    # Filter nodes (order follows node_ids)
    for node_id in node_ids:
        if node_id in node_map:
            new_spec.nodes.append(deepcopy(node_map[node_id]))

    # Add ghost nodes (boundary nodes not in this split)
    for ghost_id in ghost_ids:
        ghost = deepcopy(node_map[ghost_id])
        ghost.emphasis = config.ghost_style
        ghost.label = f"→ {ghost.label}"  # Mark as continuation
        new_spec.nodes.append(ghost)

    # Filter edges
    all_ids = node_id_set | set(ghost_ids)
    for edge in spec.edges:
        if edge.source not in all_ids or edge.target not in all_ids:
            continue
        # Ghost-to-ghost edges belong to another figure; skip them.
        if edge.source in node_id_set or edge.target in node_id_set:
            new_spec.edges.append(deepcopy(edge))

    # Filter groups
    new_spec.layout.groups = {}
    for group_name, group_nodes in spec.layout.groups.items():
        filtered = [n for n in group_nodes if n in all_ids]
        if filtered:
            new_spec.layout.groups[group_name] = filtered

    # Filter layers
    new_spec.layout.layers = []
    for layer in spec.layout.layers:
        filtered = [n for n in layer if n in all_ids]
        if filtered:
            new_spec.layout.layers.append(filtered)

    return new_spec, set(ghost_ids)