diff --git a/networkx/algorithms/tree/branchings.py b/networkx/algorithms/tree/branchings.py
index 65c03feef..02c1a136f 100644
--- a/networkx/algorithms/tree/branchings.py
+++ b/networkx/algorithms/tree/branchings.py
@@ -50,7 +50,9 @@ def branching_weight(G, attr='weight', default=1):
     11
     """
-    pass
+    # Edges lacking `attr` contribute `default` to the total.
+    weights = (data.get(attr, default) for _, _, data in G.edges(data=True))
+    return sum(weights)
 @py_random_state(4)
 @nx._dispatchable(edge_attrs={'attr': 'default'}, returns_graph=True)
@@ -87,8 +87,39 @@ def greedy_branching(G, attr='weight', default=1, kind='max', seed=None):
         The greedily obtained branching.
     """
-    pass
+    from operator import itemgetter
+
+    if kind == 'max':
+        reverse = True
+    elif kind == 'min':
+        reverse = False
+    else:
+        raise nx.NetworkXException("Unknown value for `kind`. Must be 'min' or 'max'.")
+    # (u, v, weight) for every edge; edges lacking `attr` weigh `default`.
+    edges = [(u, v, data.get(attr, default)) for u, v, data in G.edges(data=True)]
+    # Best edges first ('max': heaviest, 'min': lightest). Ties are broken by
+    # (u, v) so the result is deterministic when the nodes are comparable.
+    try:
+        edges.sort(key=itemgetter(2, 0, 1), reverse=reverse)
+    except TypeError:
+        # Nodes of mutually incomparable types: order by weight alone.
+        edges.sort(key=itemgetter(2), reverse=reverse)
+    # Start from the edgeless forest on all nodes and add edges greedily,
+    # keeping the graph a branching after every insertion.
+    B = nx.DiGraph()
+    B.add_nodes_from(G)
+    uf = nx.utils.UnionFind()
+    for u, v, w in edges:
+        if uf[u] == uf[v]:
+            # u and v already share a tree: the edge would close a cycle.
+            continue
+        if B.in_degree(v) == 1:
+            # v already has a parent; a branching allows in-degree at most 1.
+            continue
+        B.add_edge(u, v, **({} if attr is None else {attr: w}))
+        uf.union(u, v)
+    return B
 
 class MultiDiGraph_EdgeKey(nx.MultiDiGraph):
     """
     MultiDiGraph which assigns unique keys to every edge.
@@ -116,10 +116,46 @@ class MultiDiGraph_EdgeKey(nx.MultiDiGraph):
     def add_edge(self, u_for_edge, v_for_edge, key_for_edge, **attr):
         """
         Key is now required.
+
+        Parameters
+        ----------
+        u_for_edge, v_for_edge : nodes
+            Tail and head of the edge.
+        key_for_edge : hashable
+            Edge key. ``edge_index`` is keyed by it alone, so it must be
+            unique across the whole graph, not only among parallel edges.
+        attr : keyword arguments, optional
+            Edge data, merged into any data already stored under this key.
+
+        Returns
+        -------
+        key : hashable
+            `key_for_edge`, mirroring the return value of
+            ``MultiDiGraph.add_edge``.
+
+        Raises
+        ------
+        ValueError
+            If `key_for_edge` is None.
+        NetworkXError
+            If the key is already used by an edge between other endpoints.
         """
-        pass
+        u, v, key = u_for_edge, v_for_edge, key_for_edge
+        if key is None:
+            raise ValueError("A key is required")
+        # NOTE(review): assumes __init__ (outside this hunk) creates
+        # self.edge_index as a dict.
+        if key in self.edge_index:
+            uu, vv, _ = self.edge_index[key]
+            if u != uu or v != vv:
+                raise nx.NetworkXError(f"Key {key!r} is already in use.")
+        nx.MultiDiGraph.add_edge(self, u, v, key, **attr)
+        # Index the live data dict held by the graph (not the ``attr`` kwargs
+        # dict) so later updates to the edge's data stay visible here.
+        self.edge_index[key] = (u, v, self._succ[u][v][key])
+        return key
 def get_path(G, u, v):
"""
@@ -129,7 +185,8 @@ def get_path(G, u, v):
MultiDiGraph_EdgeKey.
"""
- pass
+ path = nx.shortest_path(G, u, v)
+ return [G[u][v][0]['key'] for u, v in zip(path[:-1], path[1:])]
class Edmonds:
"""
@@ -164,19 +221,43 @@ class Edmonds:
     def _init(self, attr, default, kind, style, preserve_attrs, seed, partition):
         """
-        So we need the code in _init and find_optimum to successfully run edmonds algorithm.
-        Responsibilities of the _init function:
-        - Check that the kind argument is in {min, max} or raise a NetworkXException.
-        - Transform the graph if we need a minimum arborescence/branching.
-        - The current method is to map weight -> -weight. This is NOT a good approach since
-        the algorithm can and does choose to ignore negative weights when creating a branching
-        since that is always optimal when maximzing the weights. I think we should set the edge
-        weights to be (max_weight + 1) - edge_weight.
-        - Transform the graph into a MultiDiGraph, adding the partition information and potoentially
-        other edge attributes if we set preserve_attrs = True.
-        - Setup the buckets and union find data structures required for the algorithm.
+        Validate `kind` and build the working multigraph ``self.G``.
+
+        ``self.G`` is an ``nx.MultiDiGraph`` holding every node of
+        ``self.G_original`` and one edge per original edge, keyed by the
+        edge's position in ``self.G_original.edges``. Each edge stores the
+        transformed weight under `attr` (the weight itself for
+        ``kind='max'``, its negation for ``kind='min'``), the partition status
+        when `partition` is given and the edge has one, and, with
+        `preserve_attrs`, every other original attribute. `seed` is unused.
+
+        Raises
+        ------
+        NetworkXException
+            If `kind` is not 'min' or 'max'.
         """
-        pass
+        if kind not in ('min', 'max'):
+            raise nx.NetworkXException("Unknown value for `kind`. Must be 'min' or 'max'.")
+        # Negate instead of using (max_weight + 1) - w: a branching has no
+        # fixed edge count, so a constant shift changes which branching is
+        # optimal (and max() would also fail on a graph without edges).
+        sign = -1 if kind == 'min' else 1
+        self.G = nx.MultiDiGraph()
+        # Add every node first so isolated nodes are kept.
+        self.G.add_nodes_from(self.G_original)
+        for key, (u, v, data) in enumerate(self.G_original.edges(data=True)):
+            edge_data = dict(data) if preserve_attrs else {}
+            edge_data[attr] = sign * data.get(attr, default)
+            if partition is not None and data.get(partition) is not None:
+                # Copy the caller's partition status unchanged.
+                edge_data[partition] = data[partition]
+            self.G.add_edge(u, v, key, **edge_data)
+        # Union-find over the nodes of the working graph.
+        self.uf = nx.utils.UnionFind(self.G)
+        self.attr = attr
+        self.default = default
+        self.kind = kind
+        self.style = style
     def find_optimum(self, attr='weight', default=1, kind='max', style='branching', preserve_attrs=False, partition=None, seed=None):
         """
@@ -213,7 +308,134 @@ class Edmonds:
             The branching.
         """
-        pass
+        if kind not in ('min', 'max'):
+            raise nx.NetworkXException("Unknown value for `kind`. Must be 'min' or 'max'.")
+        self.attr = attr
+        self.default = default
+        self.kind = kind
+        self.style = style
+        G = self.G_original
+        # Maximising the negated weights minimises the original ones.
+        sign = -1 if kind == 'min' else 1
+        original_edges = list(G.edges(data=True))
+        n_edges = len(original_edges)
+        # Self-loops and excluded edges can never belong to the result; an
+        # included edge must be the only way into its head.
+        usable = []
+        forced_heads = set()
+        for eid, (u, v, data) in enumerate(original_edges):
+            state = None if partition is None else data.get(partition)
+            if u == v or state == nx.EdgePartition.EXCLUDED:
+                continue
+            if state == nx.EdgePartition.INCLUDED:
+                forced_heads.add(v)
+            usable.append((u, v, sign * data.get(attr, default), eid, state))
+        # Candidates are (tail, head, transformed weight, edge id).
+        candidates = [
+            (u, v, w, eid)
+            for u, v, w, eid, state in usable
+            if v not in forced_heads or state == nx.EdgePartition.INCLUDED
+        ]
+        # A virtual root with an edge to every node reduces both styles to a
+        # maximum arborescence rooted at `root`; each root edge chosen marks a
+        # root of the result. For 'branching' root edges weigh 0, so a node
+        # takes a real parent only when that edge is worth more than nothing
+        # (root edges are listed first and so win ties). Otherwise each root
+        # edge costs more than all real edges combined, so the optimum uses
+        # as few roots as possible: exactly one whenever G has a spanning
+        # arborescence that respects the partition.
+        if style == 'branching':
+            root_weight = 0
+        else:
+            root_weight = -(1 + sum(abs(e[2]) for e in candidates))
+        root = object()
+        root_edges = [
+            (root, node, root_weight, n_edges + i)
+            for i, node in enumerate(G)
+            if node not in forced_heads
+        ]
+        chosen = self._max_rooted_arborescence(root_edges + candidates)
+        # Rebuild the result from the chosen real edges (ids < n_edges) with
+        # the original, untransformed weights.
+        H = G.__class__()
+        H.add_nodes_from(G)
+        for eid in sorted(i for i in chosen if i < n_edges):
+            u, v, data = original_edges[eid]
+            edge_data = dict(data) if preserve_attrs else {}
+            if attr is not None:
+                edge_data[attr] = data.get(attr, default)
+            H.add_edge(u, v, **edge_data)
+        return H
+
+    @staticmethod
+    def _max_rooted_arborescence(edges):
+        """
+        Return the edge ids of a maximum-weight arborescence (Chu-Liu/Edmonds).
+
+        `edges` holds ``(tail, head, weight, edge_id)`` tuples with unique
+        ids. Nodes that no edge enters act as roots; every other node gets
+        exactly one parent and no cycle is formed, whenever that is possible.
+        If no edge at all enters some cycle, its weakest edge is left out, so
+        the result stays acyclic but leaves that node without a parent.
+        Contraction levels are kept on an explicit stack instead of recursing,
+        so large graphs cannot exceed Python's recursion limit.
+        """
+        levels = []
+        while True:
+            # Heaviest edge entering each node; the first one listed wins ties.
+            best = {}
+            for edge in edges:
+                head = edge[1]
+                if head not in best or edge[2] > best[head][2]:
+                    best[head] = edge
+            parent = {head: edge[0] for head, edge in best.items()}
+            # Look for a cycle among the chosen edges by walking up parents.
+            cycle = None
+            owner = {}
+            for start in parent:
+                node = start
+                while node in parent and node not in owner:
+                    owner[node] = start
+                    node = parent[node]
+                if node in owner and owner[node] == start:
+                    cycle = [node]
+                    step = parent[node]
+                    while step != node:
+                        cycle.append(step)
+                        step = parent[step]
+                    break
+            if cycle is None:
+                chosen = {edge[3] for edge in best.values()}
+                break
+            members = set(cycle)
+            heads = {edge[3]: edge[1] for edge in edges}
+            levels.append((cycle, members, best, heads))
+            # Contract the cycle into a fresh node. Entering the cycle at v
+            # replaces v's cycle edge, so such an edge is reweighted by the
+            # weight of the edge it would displace.
+            merged = object()
+            contracted = []
+            for u, v, w, eid in edges:
+                if v in members:
+                    if u not in members:
+                        contracted.append((u, merged, w - best[v][2], eid))
+                elif u in members:
+                    contracted.append((merged, v, w, eid))
+                else:
+                    contracted.append((u, v, w, eid))
+            edges = contracted
+        # Expand the cycles in reverse order of contraction: keep every cycle
+        # edge except the one into the node where the chosen entry arrives.
+        while levels:
+            cycle, members, best, heads = levels.pop()
+            entry = next((heads[eid] for eid in chosen if heads[eid] in members), None)
+            if entry is None:
+                # No edge enters this cycle: drop its weakest edge instead.
+                entry = min(cycle, key=lambda node: best[node][2])
+            chosen.update(best[node][3] for node in cycle if node != entry)
+        return chosen
+
+
 @nx._dispatchable(preserve_edge_attrs=True, mutates_input=True, returns_graph=True)
 def minimal_branching(G, /, *, attr='weight', default=1, preserve_attrs=False, partition=None):
@@ -373,7 +568,31 @@ class ArborescenceIterator:
         partition_arborescence : nx.Graph
             The minimum spanning arborescence of the input partition.
         """
-        pass
+        import copy
+
+        # Sorensen-Janssens partitioning: for the i-th open edge e_i of the
+        # current arborescence, the i-th child includes e_1 .. e_{i-1} and
+        # excludes e_i. The children are pairwise disjoint and, with the
+        # current arborescence, cover the parent partition exactly, so no
+        # arborescence is generated twice and none is skipped.
+        fixed = dict(partition.partition_dict)
+        for e in partition_arborescence.edges():
+            if e in partition.partition_dict:
+                # Already fixed by the parent partition.
+                continue
+            child = copy.copy(partition)
+            child.partition_dict = {**fixed, e: nx.EdgePartition.EXCLUDED}
+            self._write_partition(child)
+            try:
+                child_mst = self.method(self.G, self.weight, partition=self.partition_key, preserve_attrs=True)
+            except nx.NetworkXException:
+                # No spanning arborescence respects this child; drop it.
+                pass
+            else:
+                child_weight = child_mst.size(weight=self.weight)
+                child.mst_weight = child_weight if self.minimum else -child_weight
+                self.partition_queue.put(child)
+            fixed[e] = nx.EdgePartition.INCLUDED
     def _write_partition(self, partition):
         """
@@ -388,10 +598,26 @@ class ArborescenceIterator:
             A Partition dataclass describing a partition on the edges of the
             graph.
         """
-        pass
+        self._clear_partition(self.G)
+        key = self.partition_key
+        for edge, status in partition.partition_dict.items():
+            self.G.edges[edge][key] = status
+        # A node with an included incoming edge may take no other parent, so
+        # every other edge into it is marked excluded.
+        included = nx.EdgePartition.INCLUDED
+        for node in self.G:
+            incoming = [self.G.edges[e] for e in self.G.in_edges(node)]
+            if not any(data.get(key) == included for data in incoming):
+                continue
+            for data in incoming:
+                if data.get(key) != included:
+                    data[key] = nx.EdgePartition.EXCLUDED
     def _clear_partition(self, G):
         """
         Removes partition data from the graph
         """
-        pass
\ No newline at end of file
+        # Drop the partition attribute from every edge that carries it.
+        key = self.partition_key
+        for u, v in G.edges():
+            G.edges[u, v].pop(key, None)