Source code for scitex_scholar.citation_graph.builder

"""
Citation Graph Builder

Main interface for building citation networks from CrossRef data.
"""

import json
from pathlib import Path
from typing import List, Optional

from .database import CitationDatabase
from .models import CitationEdge, CitationGraph, PaperNode


class CitationGraphBuilder:
    """
    Build citation network graphs for academic papers.

    Auto-detects backend via crossref_local.Config (DB → HTTP).

    Example (auto-detect):
        >>> builder = CitationGraphBuilder()
        >>> graph = builder.build("10.1038/s41586-020-2008-3", top_n=20)

    Example (explicit SQLite):
        >>> builder = CitationGraphBuilder(db_path="/path/to/crossref.db")

    Example (explicit HTTP):
        >>> builder = CitationGraphBuilder(api_url="http://localhost:31291")
    """

    # Score given to a DOI that is absent from the similarity scores
    # (in practice the seed papers themselves).
    _DEFAULT_SIMILARITY = 100.0
    # Titles stored on nodes are truncated to this many characters.
    _TITLE_MAX_CHARS = 200
    # Maximum number of references requested per paper when building edges.
    _EDGE_REFERENCE_LIMIT = 100

    def __init__(
        self, db_path: Optional[str] = None, api_url: Optional[str] = None
    ):
        """
        Initialize builder with database path, HTTP API URL, or auto-detect.

        When no args given, delegates to crossref_local.Config for
        auto-detection:

        1. CROSSREF_LOCAL_MODE env var (explicit "db" or "http")
        2. CROSSREF_LOCAL_API_URL env var → HTTP mode
        3. Local DB file existence → DB mode
        4. Fallback to HTTP mode

        ``api_url`` takes precedence over ``db_path`` when both are given.

        Args:
            db_path: Path to CrossRef SQLite database (local mode)
            api_url: URL of crossref-local HTTP API (HTTP mode)
        """
        if api_url:
            # Imported lazily so the HTTP backend is only loaded when used.
            from .database_http import CitationDatabaseHTTP

            self.db_path = None
            self.db = CitationDatabaseHTTP(api_url)
        elif db_path:
            self.db_path = db_path
            self.db = CitationDatabase(db_path)
        else:
            self._auto_detect()

    def _auto_detect(self):
        """Auto-detect backend via crossref_local.Config."""
        from crossref_local._core.config import Config

        mode = Config.get_mode()
        if mode == "db":
            self.db_path = str(Config.get_db_path())
            self.db = CitationDatabase(self.db_path)
        else:
            from .database_http import CitationDatabaseHTTP

            self.db_path = None
            self.db = CitationDatabaseHTTP(Config.get_api_url())

    # ------------------------------------------------------------------
    # Small pure helpers shared by the node and summary builders
    # ------------------------------------------------------------------

    @staticmethod
    def _dedupe_dois(dois) -> List[str]:
        """Return ``dois`` without case-insensitive duplicates, keeping order.

        The first spelling encountered for each DOI is the one retained.
        """
        seen = set()
        unique = []
        for doi in dois:
            key = doi.lower()
            if key not in seen:
                seen.add(key)
                unique.append(doi)
        return unique

    @staticmethod
    def _first_or(value, default):
        """Return the first element of a list-valued field, else ``default``.

        Missing, ``None``, empty-list and empty-first-element values all
        yield ``default``. A bare string is returned whole rather than
        being indexed into.
        """
        if isinstance(value, str):
            return value or default
        if isinstance(value, (list, tuple)) and value and value[0]:
            return value[0]
        return default

    @staticmethod
    def _extract_year(metadata: dict) -> int:
        """Return ``metadata["published"]["date-parts"][0][0]`` or 0.

        Any missing or empty level of that structure results in 0.
        """
        published = metadata.get("published") or {}
        date_parts = (
            published.get("date-parts") if isinstance(published, dict) else None
        )
        if date_parts and date_parts[0] and date_parts[0][0]:
            return date_parts[0][0]
        return 0

    @staticmethod
    def _author_names(metadata: dict, limit: int, initial_only: bool) -> List[str]:
        """Format up to ``limit`` authors as ``"<family> <given>"``.

        With ``initial_only`` the given name is cut to its first character.
        Missing or ``None`` name parts are rendered as empty strings.
        """
        names = []
        for author in (metadata.get("author") or [])[:limit]:
            family = author.get("family") or ""
            given = author.get("given") or ""
            if initial_only:
                given = given[:1]
            names.append(f"{family} {given}")
        return names

    def _build_nodes(self, dois: List[str], scores, seed_keys) -> List[PaperNode]:
        """Create one PaperNode per DOI and flag the seeds.

        Args:
            dois: DOIs to create nodes for (already de-duplicated)
            scores: Mapping of DOI -> similarity score
            seed_keys: Set of lower-cased seed DOIs
        """
        nodes = []
        for doi in dois:
            node = self._create_paper_node(
                doi, scores.get(doi, self._DEFAULT_SIMILARITY)
            )
            if doi.lower() in seed_keys:
                node.is_seed = True
            nodes.append(node)
        return nodes

    # ------------------------------------------------------------------
    # Public graph builders
    # ------------------------------------------------------------------

    def build(
        self,
        seed_doi: str,
        top_n: int = 20,
        weight_coupling: float = 2.0,
        weight_cocitation: float = 2.0,
        weight_direct: float = 1.0,
    ) -> CitationGraph:
        """
        Build citation network around a seed paper.

        Args:
            seed_doi: DOI of the seed paper
            top_n: Number of most similar papers to include
            weight_coupling: Weight for bibliographic coupling
            weight_cocitation: Weight for co-citation
            weight_direct: Weight for direct citations

        Returns
        -------
        CitationGraph object with nodes and edges
        """
        with self.db:
            # Calculate similarity scores
            scores = self.db.get_combined_similarity_scores(
                seed_doi,
                weight_coupling=weight_coupling,
                weight_cocitation=weight_cocitation,
                weight_direct=weight_direct,
            )

            # Seed first, then the top N most similar papers. De-duplicated
            # so a seed that also appears in the scores yields one node only.
            top_dois = self._dedupe_dois(
                [seed_doi] + [doi for doi, _ in scores.most_common(top_n)]
            )

            # Build nodes with metadata
            nodes = self._build_nodes(top_dois, scores, {seed_doi.lower()})

            # Build edges (citations between papers in network)
            edges = self._build_citation_edges(top_dois)

            # Create graph
            graph = CitationGraph(
                seed_doi=seed_doi,
                seed_dois=[seed_doi],
                nodes=nodes,
                edges=edges,
                metadata={
                    "top_n": top_n,
                    "weights": {
                        "coupling": weight_coupling,
                        "cocitation": weight_cocitation,
                        "direct": weight_direct,
                    },
                },
            )

        return graph

    def _create_paper_node(self, doi: str, similarity_score: float) -> PaperNode:
        """
        Create a PaperNode with metadata from database.

        When the database has no metadata for ``doi`` the node carries only
        the DOI and the score.

        Args:
            doi: DOI of the paper
            similarity_score: Calculated similarity score

        Returns
        -------
        PaperNode object
        """
        metadata = self.db.get_paper_metadata(doi)
        if not metadata:
            return PaperNode(doi=doi, similarity_score=similarity_score)

        title = self._first_or(metadata.get("title"), "Unknown")
        return PaperNode(
            doi=doi,
            title=title[: self._TITLE_MAX_CHARS],
            year=self._extract_year(metadata),
            # First three authors as "<family> <initial>"
            authors=self._author_names(metadata, limit=3, initial_only=True),
            journal=self._first_or(metadata.get("container-title"), ""),
            similarity_score=similarity_score,
        )

    def _build_citation_edges(self, dois: List[str]) -> List[CitationEdge]:
        """
        Build citation edges between papers in the network.

        References are matched against the network case-insensitively, and
        the edge target uses the spelling from ``dois`` so that it lines up
        with the corresponding node.

        Args:
            dois: List of DOIs in the network

        Returns
        -------
        List of CitationEdge objects
        """
        # lower-cased DOI -> spelling used in the network (first one wins)
        canonical = {}
        for doi in dois:
            canonical.setdefault(doi.lower(), doi)

        edges = []
        for doi in dois:
            # Get references (papers this one cites)
            refs = self.db.get_references(doi, limit=self._EDGE_REFERENCE_LIMIT)
            for cited_doi in refs:
                if not cited_doi:
                    continue
                target = canonical.get(cited_doi.lower())
                if target is not None:
                    edges.append(
                        CitationEdge(
                            source=doi,
                            target=target,
                            edge_type="cites",
                        )
                    )
        return edges

    def build_from_dois(
        self,
        dois: List[str],
        num_related_per_doi: int = 20,
        weight_coupling: float = 2.0,
        weight_cocitation: float = 2.0,
        weight_direct: float = 1.0,
    ) -> CitationGraph:
        """
        Build citation network from multiple seed DOIs.

        Combines similarity scores from all seeds to find papers related
        to the entire set, producing a richer connected graph.

        An empty ``dois`` returns an empty graph (with an ``"error"`` entry
        in its metadata) without touching the database.

        Args:
            dois: List of seed DOIs
            num_related_per_doi: Number of related papers to discover per DOI
            weight_coupling: Weight for bibliographic coupling
            weight_cocitation: Weight for co-citation
            weight_direct: Weight for direct citations

        Returns
        -------
        CitationGraph with all seeds + related papers + edges
        """
        dois = list(dois)
        metadata = {
            "num_related_per_doi": num_related_per_doi,
            "num_seeds": len(dois),
            "weights": {
                "coupling": weight_coupling,
                "cocitation": weight_cocitation,
                "direct": weight_direct,
            },
        }

        if not dois:
            metadata["error"] = "No seed DOIs provided"
            return CitationGraph(
                seed_doi="",
                seed_dois=[],
                nodes=[],
                edges=[],
                metadata=metadata,
            )

        with self.db:
            seed_keys = {doi.lower() for doi in dois}

            # One batch call for all seeds. NOTE(review): the original
            # comment says this issues 4 SQL queries regardless of DOI
            # count; that depends on the backend and is not visible here.
            combined_scores = self.db.get_combined_similarity_scores_batch(
                dois,
                weight_coupling=weight_coupling,
                weight_cocitation=weight_cocitation,
                weight_direct=weight_direct,
            )

            # Top N related papers (scaled by number of seeds)
            top_count = num_related_per_doi * len(dois)
            related_dois = [
                doi for doi, _ in combined_scores.most_common(top_count)
            ]

            # All DOIs = seeds + related, without duplicates (a seed may
            # itself rank among the related papers of another seed).
            all_dois = self._dedupe_dois(dois + related_dois)

            # Build nodes
            nodes = self._build_nodes(all_dois, combined_scores, seed_keys)

            # Build edges
            edges = self._build_citation_edges(all_dois)

        return CitationGraph(
            seed_doi=dois[0],
            seed_dois=list(dois),
            nodes=nodes,
            edges=edges,
            metadata=metadata,
        )

    def build_from_query(
        self,
        query: str,
        num_related_per_doi: int = 20,
        search_limit: int = 10,
        weight_coupling: float = 2.0,
        weight_cocitation: float = 2.0,
        weight_direct: float = 1.0,
    ) -> CitationGraph:
        """
        Build citation network from a text query.

        Searches local databases, extracts DOIs from results,
        then delegates to build_from_dois().

        Args:
            query: Search query (e.g. "hippocampal sharp wave ripples")
            num_related_per_doi: Related papers per seed DOI
            search_limit: Max papers to fetch from search
            weight_coupling: Weight for bibliographic coupling
            weight_cocitation: Weight for co-citation
            weight_direct: Weight for direct citations

        Returns
        -------
        CitationGraph with search-discovered seeds + related papers
        """
        from ..local_dbs.unified import search

        results = search(query, limit=search_limit)
        # Keep only results that carry a non-blank DOI.
        dois = [w.doi for w in results.works if w.doi and w.doi.strip()]

        if not dois:
            return CitationGraph(
                seed_doi="",
                seed_dois=[],
                nodes=[],
                edges=[],
                metadata={"query": query, "error": "No papers with DOI found"},
            )

        graph = self.build_from_dois(
            dois=dois,
            num_related_per_doi=num_related_per_doi,
            weight_coupling=weight_coupling,
            weight_cocitation=weight_cocitation,
            weight_direct=weight_direct,
        )
        graph.metadata["query"] = query
        graph.metadata["search_results_count"] = len(results.works)
        return graph

    def export_json(self, graph: CitationGraph, output_path: str):
        """
        Export graph to JSON file for visualization.

        Args:
            graph: CitationGraph to export
            output_path: Path to output JSON file
        """
        output = Path(output_path)
        # Explicit encoding so the result does not depend on the platform
        # default locale.
        with open(output, "w", encoding="utf-8") as f:
            json.dump(graph.to_dict(), f, indent=2)

    def get_paper_summary(self, doi: str) -> Optional[dict]:
        """
        Get summary information for a paper.

        Reference and citation counts are the lengths of lists fetched with
        ``limit=1000``.

        Args:
            doi: DOI of the paper

        Returns
        -------
        Dictionary with paper summary, or None when the database has no
        metadata for the DOI
        """
        with self.db:
            metadata = self.db.get_paper_metadata(doi)
            if not metadata:
                return None

            # Get citation counts
            refs = self.db.get_references(doi, limit=1000)
            citations = self.db.get_citations(doi, limit=1000)

            return {
                "doi": doi,
                "title": self._first_or(metadata.get("title"), "Unknown"),
                "year": self._extract_year(metadata),
                # First five authors as "<family> <given>"
                "authors": self._author_names(
                    metadata, limit=5, initial_only=False
                ),
                "journal": self._first_or(
                    metadata.get("container-title"), "Unknown"
                ),
                "reference_count": len(refs),
                "citation_count": len(citations),
            }