4
4
from numpy .typing import ArrayLike
5
5
6
6
import pywhy_graphs
7
- import pywhy_graphs .networkx as pywhy_nx
8
7
from pywhy_graphs .classes .functions import edge_types
9
8
from pywhy_graphs .config import CLearnEndpoint , EdgeType
10
9
from pywhy_graphs .typing import Node
11
10
12
11
13
- def _graph_to_clearn_arr (G : pywhy_nx . MixedEdgeGraph ) -> Tuple [ArrayLike , List [Node ]]:
12
+ def _graph_to_clearn_arr (G ) -> Tuple [ArrayLike , List [Node ]]:
14
13
# define the array
15
14
arr = np .zeros ((G .number_of_nodes (), G .number_of_nodes ()), dtype = int )
16
15
@@ -125,9 +124,7 @@ def _graph_to_clearn_arr(G: pywhy_nx.MixedEdgeGraph) -> Tuple[ArrayLike, List[No
125
124
return arr , arr_idx
126
125
127
126
128
- def clearn_arr_to_graph (
129
- arr : ArrayLike , arr_idx : List [Node ], graph_type : str
130
- ) -> pywhy_nx .MixedEdgeGraph :
127
+ def clearn_arr_to_graph (arr : ArrayLike , arr_idx : List [Node ], graph_type : str ):
131
128
"""Convert causal-learn array to a graph object.
132
129
133
130
Parameters
@@ -170,7 +167,7 @@ def clearn_arr_to_graph(
170
167
elif graph_type == "admg" :
171
168
graph = pywhy_graphs .ADMG ()
172
169
elif graph_type == "cpdag" :
173
- graph = pywhy_graphs .CPDAG ()
170
+ graph = pywhy_graphs .CPDAG () # type: ignore
174
171
elif graph_type == "pag" :
175
172
graph = pywhy_graphs .PAG ()
176
173
else :
@@ -248,18 +245,21 @@ def clearn_arr_to_graph(
248
245
elif (endpoint_v == CLearnEndpoint .TAIL ) and (endpoint_u == CLearnEndpoint .TAIL ):
249
246
graph .add_edge (u , v , edge_type = graph .undirected_edge_name )
250
247
else :
248
+ if not hasattr (graph , "circle_edge_name" ):
249
+ raise RuntimeError (f"Graph { graph } is adding circular end points..." )
250
+
251
251
# Endpoints contain a circle...
252
252
# u o- v
253
253
if endpoint_u == CLearnEndpoint .CIRCLE :
254
- graph .add_edge (v , u , edge_type = graph .circle_edge_name )
254
+ graph .add_edge (v , u , edge_type = graph .circle_edge_name ) # type: ignore
255
255
elif endpoint_u == CLearnEndpoint .ARROW :
256
256
graph .add_edge (v , u , edge_type = graph .directed_edge_name )
257
257
elif endpoint_u == CLearnEndpoint .TAIL :
258
258
graph .add_edge (v , u , edge_type = graph .undirected_edge_name )
259
259
260
260
# u -o v
261
261
if endpoint_v == CLearnEndpoint .CIRCLE :
262
- graph .add_edge (u , v , edge_type = graph .circle_edge_name )
262
+ graph .add_edge (u , v , edge_type = graph .circle_edge_name ) # type: ignore
263
263
elif endpoint_v == CLearnEndpoint .ARROW :
264
264
graph .add_edge (u , v , edge_type = graph .directed_edge_name )
265
265
elif endpoint_v == CLearnEndpoint .TAIL :
@@ -271,7 +271,7 @@ def clearn_arr_to_graph(
271
271
272
272
273
273
def graph_to_arr (
274
- G : pywhy_nx . MixedEdgeGraph ,
274
+ G ,
275
275
format : str = "causal-learn" ,
276
276
node_order : Optional [ArrayLike ] = None ,
277
277
) -> Tuple [ArrayLike , List [Node ]]:
0 commit comments