77import networkx as nx
88import matplotlib .pyplot as plt
99
10- from ..models .image import ImageField
11- from ..services .graph import GraphExecutionState
10+ from ..invocations .baseinvocation import BaseInvocation
11+ from ..invocations .image import ImageField
12+ from ..services .graph import GraphExecutionState , LibraryGraph , GraphInvocation , Edge
1213from ..services .invoker import Invoker
1314
1415
16+ def add_field_argument (command_parser , name : str , field , default_override = None ):
17+ default = default_override if default_override is not None else field .default if field .default_factory is None else field .default_factory ()
18+ if get_origin (field .type_ ) == Literal :
19+ allowed_values = get_args (field .type_ )
20+ allowed_types = set ()
21+ for val in allowed_values :
22+ allowed_types .add (type (val ))
23+ allowed_types_list = list (allowed_types )
24+ field_type = allowed_types_list [0 ] if len (allowed_types ) == 1 else Union [allowed_types_list ] # type: ignore
25+
26+ command_parser .add_argument (
27+ f"--{ name } " ,
28+ dest = name ,
29+ type = field_type ,
30+ default = default ,
31+ choices = allowed_values ,
32+ help = field .field_info .description ,
33+ )
34+ else :
35+ command_parser .add_argument (
36+ f"--{ name } " ,
37+ dest = name ,
38+ type = field .type_ ,
39+ default = default ,
40+ help = field .field_info .description ,
41+ )
42+
43+
1544def add_parsers (
1645 subparsers ,
1746 commands : list [type ],
@@ -36,48 +65,65 @@ def add_parsers(
3665 if name in exclude_fields :
3766 continue
3867
39- if get_origin (field .type_ ) == Literal :
40- allowed_values = get_args (field .type_ )
41- allowed_types = set ()
42- for val in allowed_values :
43- allowed_types .add (type (val ))
44- allowed_types_list = list (allowed_types )
45- field_type = allowed_types_list [0 ] if len (allowed_types ) == 1 else Union [allowed_types_list ] # type: ignore
46-
47- command_parser .add_argument (
48- f"--{ name } " ,
49- dest = name ,
50- type = field_type ,
51- default = field .default if field .default_factory is None else field .default_factory (),
52- choices = allowed_values ,
53- help = field .field_info .description ,
54- )
55- else :
56- command_parser .add_argument (
57- f"--{ name } " ,
58- dest = name ,
59- type = field .type_ ,
60- default = field .default if field .default_factory is None else field .default_factory (),
61- help = field .field_info .description ,
62- )
68+ add_field_argument (command_parser , name , field )
69+
70+
71+ def add_graph_parsers (
72+ subparsers ,
73+ graphs : list [LibraryGraph ],
74+ add_arguments : Callable [[argparse .ArgumentParser ], None ]| None = None
75+ ):
76+ for graph in graphs :
77+ command_parser = subparsers .add_parser (graph .name , help = graph .description )
78+
79+ if add_arguments is not None :
80+ add_arguments (command_parser )
81+
82+ # Add arguments for inputs
83+ for exposed_input in graph .exposed_inputs :
84+ node = graph .graph .get_node (exposed_input .node_path )
85+ field = node .__fields__ [exposed_input .field ]
86+ default_override = getattr (node , exposed_input .field )
87+ add_field_argument (command_parser , exposed_input .alias , field , default_override )
6388
6489
6590class CliContext :
6691 invoker : Invoker
6792 session : GraphExecutionState
6893 parser : argparse .ArgumentParser
6994 defaults : dict [str , Any ]
95+ graph_nodes : dict [str , str ]
96+ nodes_added : list [str ]
7097
7198 def __init__ (self , invoker : Invoker , session : GraphExecutionState , parser : argparse .ArgumentParser ):
7299 self .invoker = invoker
73100 self .session = session
74101 self .parser = parser
75102 self .defaults = dict ()
103+ self .graph_nodes = dict ()
104+ self .nodes_added = list ()
76105
77106 def get_session (self ):
78107 self .session = self .invoker .services .graph_execution_manager .get (self .session .id )
79108 return self .session
80109
110+ def reset (self ):
111+ self .session = self .invoker .create_execution_state ()
112+ self .graph_nodes = dict ()
113+ self .nodes_added = list ()
114+ # Leave defaults unchanged
115+
116+ def add_node (self , node : BaseInvocation ):
117+ self .get_session ()
118+ self .session .graph .add_node (node )
119+ self .nodes_added .append (node .id )
120+ self .invoker .services .graph_execution_manager .set (self .session )
121+
122+ def add_edge (self , edge : Edge ):
123+ self .get_session ()
124+ self .session .add_edge (edge )
125+ self .invoker .services .graph_execution_manager .set (self .session )
126+
81127
82128class ExitCli (Exception ):
83129 """Exception to exit the CLI"""
0 commit comments