Source code for scitex_introspect._call_graph

#!/usr/bin/env python3
# Timestamp: 2025-01-20
# File: /home/ywatanabe/proj/scitex-code/src/scitex/introspect/_call_graph.py

"""Call graph analysis using AST with timeout protection."""

from __future__ import annotations

import ast
import inspect
import signal
from contextlib import contextmanager
from pathlib import Path

from ._resolve import get_type_info, resolve_object


class TimeoutError(Exception):
    """Raised when operation times out."""

    pass


@contextmanager
def timeout(seconds: int):
    """Context manager for timeout (Unix only)."""

    def handler(signum, frame):
        raise TimeoutError(f"Operation timed out after {seconds}s")

    # Only works on Unix
    try:
        old_handler = signal.signal(signal.SIGALRM, handler)
        signal.alarm(seconds)
        try:
            yield
        finally:
            signal.alarm(0)
            signal.signal(signal.SIGALRM, old_handler)
    except (ValueError, AttributeError):
        # Windows or signal not available - no timeout
        yield


[docs] def get_call_graph( dotted_path: str, max_depth: int = 2, timeout_seconds: int = 10, internal_only: bool = True, ) -> dict: """ Get the call graph of a function or module using static AST analysis. Parameters ---------- dotted_path : str Dotted path to the function or module max_depth : int Maximum depth to traverse calls timeout_seconds : int Timeout in seconds (0 = no timeout) internal_only : bool Only show calls to functions in the same module Returns ------- dict calls: list[dict] - Functions this function calls called_by: list[dict] - Functions that call this (if module) graph: dict - Full call graph tree Examples -------- >>> get_call_graph("scitex.audio.speak") """ try: if timeout_seconds > 0: with timeout(timeout_seconds): return _analyze_call_graph(dotted_path, max_depth, internal_only) else: return _analyze_call_graph(dotted_path, max_depth, internal_only) except TimeoutError as e: return { "success": False, "error": str(e), "partial": True, }
def _analyze_call_graph( dotted_path: str, max_depth: int, internal_only: bool, ) -> dict: """Perform the actual call graph analysis.""" obj, error = resolve_object(dotted_path) if error: return {"success": False, "error": error} type_info = get_type_info(obj) # Get source file try: source_file = inspect.getfile(obj) source = Path(source_file).read_text() tree = ast.parse(source) except Exception as e: return { "success": False, "error": f"Cannot parse source: {e}", "type_info": type_info, } # Build function index for the module func_index = _build_function_index(tree) if inspect.isfunction(obj): # Analyze single function func_name = obj.__name__ if func_name not in func_index: return { "success": False, "error": f"Function '{func_name}' not found in source", "type_info": type_info, } calls = _get_function_calls(func_index[func_name], internal_only, func_index) called_by = _find_callers(func_name, func_index) return { "success": True, "function": func_name, "calls": calls, "call_count": len(calls), "called_by": called_by, "caller_count": len(called_by), "type_info": type_info, } elif inspect.ismodule(obj): # Analyze entire module graph = {} for func_name, func_node in func_index.items(): calls = _get_function_calls(func_node, internal_only, func_index) graph[func_name] = { "calls": calls, "line": func_node.lineno, } return { "success": True, "module": dotted_path, "graph": graph, "function_count": len(graph), "type_info": type_info, } else: return { "success": False, "error": "Can only analyze functions or modules", "type_info": type_info, } def _build_function_index(tree: ast.AST) -> dict[str, ast.FunctionDef]: """Build index of all functions in the AST.""" index = {} for node in ast.walk(tree): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): index[node.name] = node return index def _get_function_calls( func_node: ast.FunctionDef, internal_only: bool, func_index: dict, ) -> list[dict]: """Extract all function calls from a function.""" calls = [] seen = set() for node in ast.walk(func_node): if isinstance(node, ast.Call): call_info = _extract_call_info(node) if call_info and call_info["name"] not in seen: # Filter to internal only if requested if internal_only and call_info["name"] not in func_index: continue seen.add(call_info["name"]) calls.append(call_info) return calls def _extract_call_info(node: ast.Call) -> dict | None: """Extract information about a function call.""" func = node.func if isinstance(func, ast.Name): # Simple call: func() return { "name": func.id, "type": "function", "line": node.lineno, } elif isinstance(func, ast.Attribute): # Method call: obj.method() if isinstance(func.value, ast.Name): return { "name": f"{func.value.id}.{func.attr}", "type": "method", "object": func.value.id, "method": func.attr, "line": node.lineno, } else: return { "name": func.attr, "type": "method", "method": func.attr, "line": node.lineno, } return None def _find_callers( func_name: str, func_index: dict[str, ast.FunctionDef], ) -> list[dict]: """Find all functions that call the given function.""" callers = [] for caller_name, caller_node in func_index.items(): if caller_name == func_name: continue for node in ast.walk(caller_node): if isinstance(node, ast.Call): call_info = _extract_call_info(node) if call_info and call_info["name"] == func_name: callers.append( { "name": caller_name, "line": caller_node.lineno, } ) break return callers
[docs] def get_function_calls( dotted_path: str, include_methods: bool = True, include_builtins: bool = False, ) -> dict: """ Get just the outgoing calls from a function. Simpler version of get_call_graph for quick lookup. Parameters ---------- dotted_path : str Dotted path to the function include_methods : bool Include method calls (obj.method()) include_builtins : bool Include builtin function calls Returns ------- dict calls: list[str] - Names of called functions """ result = get_call_graph(dotted_path, max_depth=1, internal_only=False) if not result.get("success"): return result calls = result.get("calls", []) # Filter filtered = [] builtins = {"print", "len", "range", "str", "int", "float", "list", "dict", "set"} for call in calls: name = call["name"] if not include_methods and call.get("type") == "method": continue if not include_builtins and name in builtins: continue filtered.append(name) return { "success": True, "function": dotted_path, "calls": filtered, "call_count": len(filtered), }