# -*- coding: utf-8 -*-
"""
centrality_disruption.py
This module implements the Centrality Disruption Index (CDI) metric for
measuring how perturbations affect the centrality distribution of nodes
in a hypergraph.
"""
from collections import deque
from typing import TYPE_CHECKING, Optional, Dict, Callable
import numpy as np
from scipy import stats
if TYPE_CHECKING:
from hiper.core import Hypernetwork
class CentralityDisruptionIndex:
"""
Computes the Centrality Disruption Index (CDI) metric.
The Centrality Disruption Index measures how much a perturbation affects
the centrality distribution of nodes using the Kolmogorov-Smirnov test:
:math:`\\text{CDI} = \\text{KS}(\\{C^{\\text{pre}}_v\\}, \\{C^{\\text{post}}_v\\})`
where:
- :math:`C^{\\text{pre}}_v` are centrality values before perturbation
- :math:`C^{\\text{post}}_v` are centrality values after perturbation
- KS is the Kolmogorov-Smirnov distance between distributions
The metric returns a value in [0, 1] where:
- 0.0 indicates no disruption (identical distributions)
- 1.0 indicates maximum disruption (completely different distributions)
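    Example:
        A minimal usage sketch; ``before_hg`` and ``after_hg`` are assumed to
        be two :class:`~hiper.core.Hypernetwork` instances captured before and
        after a perturbation (their construction is not shown here)::

            cdi = CentralityDisruptionIndex(centrality_type='degree')
            disruption = cdi.compute(before_hg, after_hg)
            print(f"CDI = {disruption:.3f}")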
"""
    def __init__(self, centrality_type: str = 'degree'):
        """
        Initialize the Centrality Disruption Index metric.
        Args:
            centrality_type: Type of built-in centrality to use ('degree',
                'closeness', or 'betweenness'). A custom centrality function
                can instead be supplied via the ``centrality_function``
                argument of :meth:`compute`.
        """
self.name = "Centrality Disruption Index"
self.symbol = "CDI"
self.centrality_type = centrality_type
def compute(self,
before_hypergraph: 'Hypernetwork',
after_hypergraph: 'Hypernetwork',
centrality_function: Optional[Callable] = None) -> float:
"""
Compute the Centrality Disruption Index metric.
Args:
before_hypergraph: The hypergraph before perturbation.
after_hypergraph: The hypergraph after perturbation.
            centrality_function: Optional custom centrality function that takes
                a hypergraph and returns a dict mapping node IDs to centrality
                values. If None, the built-in centrality selected by
                ``centrality_type`` is used.
Returns:
The centrality disruption index value in [0, 1].
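        Example:
            A minimal sketch of passing a custom centrality function; the
            ``uniform_centrality`` helper below is purely illustrative and
            assumes, as elsewhere in this module, that ``hypergraph.nodes``
            is a mapping keyed by node ID::

                def uniform_centrality(hypergraph):
                    # Every node gets the same score (no disruption signal).
                    return {node: 1.0 for node in hypergraph.nodes.keys()}

                cdi = CentralityDisruptionIndex()
                value = cdi.compute(before_hg, after_hg,
                                    centrality_function=uniform_centrality)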
"""
# Compute centralities before and after
centralities_before = self._compute_centralities(
before_hypergraph, centrality_function)
centralities_after = self._compute_centralities(
after_hypergraph, centrality_function)
# Get common nodes for fair comparison
common_nodes = set(centralities_before.keys()).intersection(
set(centralities_after.keys()))
if len(common_nodes) < 2:
# Not enough nodes for meaningful comparison
            return (1.0 if len(centralities_before) != len(centralities_after)
                    else 0.0)
# Extract centrality values for common nodes
values_before = [centralities_before[node] for node in common_nodes]
values_after = [centralities_after[node] for node in common_nodes]
# Handle edge cases
if len(values_before) == 0 or len(values_after) == 0:
return 1.0
# Compute Kolmogorov-Smirnov statistic
try:
ks_statistic, _ = stats.ks_2samp(values_before, values_after)
return min(1.0, max(0.0, ks_statistic))
except (ValueError, RuntimeError):
# Fallback to normalized absolute difference of means
mean_before = np.mean(values_before)
mean_after = np.mean(values_after)
max_val = max(max(values_before), max(values_after))
if max_val == 0:
return 0.0
return min(1.0, abs(mean_before - mean_after) / max_val)
def _compute_centralities(self,
hypergraph: 'Hypernetwork',
centrality_function: Optional[Callable] = None
) -> Dict[int, float]:
"""
Compute centrality values for all nodes in the hypergraph.
Args:
hypergraph: The hypergraph to analyze.
            centrality_function: Custom centrality function; if provided it
                takes precedence over ``centrality_type``.
Returns:
Dictionary mapping node IDs to centrality values.
"""
if centrality_function is not None:
return centrality_function(hypergraph)
if self.centrality_type == 'degree':
return self._compute_degree_centrality(hypergraph)
elif self.centrality_type == 'closeness':
return self._compute_closeness_centrality(hypergraph)
elif self.centrality_type == 'betweenness':
return self._compute_betweenness_centrality(hypergraph)
else:
# Default to degree centrality
return self._compute_degree_centrality(hypergraph)
    @staticmethod
    def _compute_degree_centrality(
            hypergraph: 'Hypernetwork') -> Dict[int, float]:
"""Compute degree centrality for all nodes."""
centralities = {}
for node in hypergraph.nodes.keys():
# Count how many hyperedges contain this node
degree = len(hypergraph.get_hyperedges(node))
centralities[node] = float(degree)
return centralities
def _compute_closeness_centrality(self, hypergraph: 'Hypernetwork'
) -> Dict[int, float]:
"""Compute closeness centrality based on hypergraph distances."""
centralities = {}
nodes = list(hypergraph.nodes.keys())
if len(nodes) <= 1:
return {node: 1.0 for node in nodes}
# Build adjacency based on hyperedge co-membership
adjacency = self._build_node_adjacency(hypergraph)
        for node in nodes:
            distances = self._compute_shortest_distances(adjacency, node)
            # distances always contains the source itself (distance 0), so an
            # isolated node yields total_distance == 0 and gets closeness 0.
            total_distance = sum(distances.values())
            if total_distance > 0:
                centralities[node] = (len(distances) - 1) / total_distance
            else:
                centralities[node] = 0.0
return centralities
@staticmethod
def _compute_betweenness_centrality(hypergraph: 'Hypernetwork'
) -> Dict[int, float]:
"""
Compute betweenness centrality.
"""
centralities = {}
nodes = list(hypergraph.nodes.keys())
for node in nodes:
centralities[node] = 0.0
# Betweenness based on hyperedge participation
for node in nodes:
node_hyperedges = hypergraph.get_hyperedges(node)
# Count how many node pairs this node connects
connected_pairs = 0
for he_id in node_hyperedges:
he_nodes = hypergraph.get_nodes(he_id)
other_nodes = [n for n in he_nodes if n != node]
pairs_count = len(other_nodes) * (len(other_nodes) - 1) // 2
connected_pairs += pairs_count
centralities[node] = float(connected_pairs)
return centralities
@staticmethod
def _build_node_adjacency(hypergraph: 'Hypernetwork') -> Dict[int, set]:
"""Build node adjacency graph based on hyperedge co-membership."""
adjacency = {node: set() for node in hypergraph.nodes.keys()}
for he_id in hypergraph.edges.keys():
he_nodes = hypergraph.get_nodes(he_id)
# Connect all pairs in the hyperedge
for i, node1 in enumerate(he_nodes):
for node2 in he_nodes[i + 1:]:
adjacency[node1].add(node2)
adjacency[node2].add(node1)
return adjacency
@staticmethod
def _compute_shortest_distances(adjacency: Dict[int, set],
source: int) -> Dict[int, int]:
"""Compute the shortest distances from source using BFS."""
distances = {source: 0}
        queue = deque([source])
visited = {source}
while queue:
            current = queue.popleft()
current_dist = distances[current]
for neighbor in adjacency.get(current, set()):
if neighbor not in visited:
visited.add(neighbor)
distances[neighbor] = current_dist + 1
queue.append(neighbor)
return distances
def compute_detailed(self,
before_hypergraph: 'Hypernetwork',
after_hypergraph: 'Hypernetwork',
centrality_function: Optional[Callable] = None
) -> dict:
"""
Compute detailed centrality disruption analysis.
Args:
before_hypergraph: The hypergraph before perturbation.
after_hypergraph: The hypergraph after perturbation.
centrality_function: Custom centrality function.
Returns:
Dictionary containing detailed analysis results.
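        Example:
            A sketch of inspecting the report; ``before_hg`` and ``after_hg``
            are assumed to be :class:`~hiper.core.Hypernetwork` instances::

                report = cdi.compute_detailed(before_hg, after_hg)
                print(report['ks_statistic'], report['p_value'])
                print(report['common_nodes'], "nodes compared")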
"""
centralities_before = self._compute_centralities(
before_hypergraph, centrality_function)
centralities_after = self._compute_centralities(
after_hypergraph, centrality_function)
common_nodes = set(centralities_before.keys()).intersection(
set(centralities_after.keys()))
values_before = [centralities_before[node] for node in common_nodes]
values_after = [centralities_after[node] for node in common_nodes]
ks_statistic = 0.0
p_value = 1.0
if len(values_before) >= 2 and len(values_after) >= 2:
try:
ks_statistic, p_value = stats.ks_2samp(values_before,
values_after)
except (ValueError, RuntimeError):
pass
return {
'ks_statistic': ks_statistic,
'p_value': p_value,
'nodes_before': len(centralities_before),
'nodes_after': len(centralities_after),
'common_nodes': len(common_nodes),
'centralities_before': centralities_before,
'centralities_after': centralities_after,
'mean_before': np.mean(values_before) if values_before else 0.0,
'mean_after': np.mean(values_after) if values_after else 0.0,
'std_before': np.std(values_before) if values_before else 0.0,
'std_after': np.std(values_after) if values_after else 0.0
}
def __str__(self) -> str:
"""String representation of the metric."""
return (f"{self.name} ({self.symbol}) - "
f"{self.centrality_type} centrality")
def __repr__(self) -> str:
"""Detailed string representation."""
return (f"CentralityDisruptionIndex(name='{self.name}', "
f"symbol='{self.symbol}', "
f"centrality_type='{self.centrality_type}')")