Skip to content

Commit 23d65e7

Browse files
authored
[nodes] Add subgraph library, subgraph usage in CLI, and fix subgraph execution (#3180)
* Add latent to latent (img2img equivalent) Fix a CLI bug with multiple links per node * Using "latents" instead of "latent" * [nodes] In-progress implementation of graph library * Add linking to CLI for graph nodes (still broken) * Fix subgraph execution, fix subgraph linking in CLI * Fix LatentsToLatents
1 parent 024fd54 commit 23d65e7

File tree

13 files changed

+472
-104
lines changed

13 files changed

+472
-104
lines changed

invokeai/app/api/dependencies.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import os
44
from argparse import Namespace
55

6+
from ..services.default_graphs import create_system_graphs
7+
68
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
79

810
from ...backend import Globals
911
from ..services.model_manager_initializer import get_model_manager
1012
from ..services.restoration_services import RestorationServices
11-
from ..services.graph import GraphExecutionState
13+
from ..services.graph import GraphExecutionState, LibraryGraph
1214
from ..services.image_storage import DiskImageStorage
1315
from ..services.invocation_queue import MemoryInvocationQueue
1416
from ..services.invocation_services import InvocationServices
@@ -69,13 +71,18 @@ def initialize(config, event_handler_id: int):
6971
latents=latents,
7072
images=images,
7173
queue=MemoryInvocationQueue(),
74+
graph_library=SqliteItemStorage[LibraryGraph](
75+
filename=db_location, table_name="graphs"
76+
),
7277
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
7378
filename=db_location, table_name="graph_executions"
7479
),
7580
processor=DefaultInvocationProcessor(),
7681
restoration=RestorationServices(config),
7782
)
7883

84+
create_system_graphs(services.graph_library)
85+
7986
ApiDependencies.invoker = Invoker(services)
8087

8188
@staticmethod

invokeai/app/cli/commands.py

Lines changed: 72 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,40 @@
77
import networkx as nx
88
import matplotlib.pyplot as plt
99

10-
from ..models.image import ImageField
11-
from ..services.graph import GraphExecutionState
10+
from ..invocations.baseinvocation import BaseInvocation
11+
from ..invocations.image import ImageField
12+
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
1213
from ..services.invoker import Invoker
1314

1415

16+
def add_field_argument(command_parser, name: str, field, default_override = None):
17+
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
18+
if get_origin(field.type_) == Literal:
19+
allowed_values = get_args(field.type_)
20+
allowed_types = set()
21+
for val in allowed_values:
22+
allowed_types.add(type(val))
23+
allowed_types_list = list(allowed_types)
24+
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
25+
26+
command_parser.add_argument(
27+
f"--{name}",
28+
dest=name,
29+
type=field_type,
30+
default=default,
31+
choices=allowed_values,
32+
help=field.field_info.description,
33+
)
34+
else:
35+
command_parser.add_argument(
36+
f"--{name}",
37+
dest=name,
38+
type=field.type_,
39+
default=default,
40+
help=field.field_info.description,
41+
)
42+
43+
1544
def add_parsers(
1645
subparsers,
1746
commands: list[type],
@@ -36,48 +65,65 @@ def add_parsers(
3665
if name in exclude_fields:
3766
continue
3867

39-
if get_origin(field.type_) == Literal:
40-
allowed_values = get_args(field.type_)
41-
allowed_types = set()
42-
for val in allowed_values:
43-
allowed_types.add(type(val))
44-
allowed_types_list = list(allowed_types)
45-
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
46-
47-
command_parser.add_argument(
48-
f"--{name}",
49-
dest=name,
50-
type=field_type,
51-
default=field.default if field.default_factory is None else field.default_factory(),
52-
choices=allowed_values,
53-
help=field.field_info.description,
54-
)
55-
else:
56-
command_parser.add_argument(
57-
f"--{name}",
58-
dest=name,
59-
type=field.type_,
60-
default=field.default if field.default_factory is None else field.default_factory(),
61-
help=field.field_info.description,
62-
)
68+
add_field_argument(command_parser, name, field)
69+
70+
71+
def add_graph_parsers(
72+
subparsers,
73+
graphs: list[LibraryGraph],
74+
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
75+
):
76+
for graph in graphs:
77+
command_parser = subparsers.add_parser(graph.name, help=graph.description)
78+
79+
if add_arguments is not None:
80+
add_arguments(command_parser)
81+
82+
# Add arguments for inputs
83+
for exposed_input in graph.exposed_inputs:
84+
node = graph.graph.get_node(exposed_input.node_path)
85+
field = node.__fields__[exposed_input.field]
86+
default_override = getattr(node, exposed_input.field)
87+
add_field_argument(command_parser, exposed_input.alias, field, default_override)
6388

6489

6590
class CliContext:
6691
invoker: Invoker
6792
session: GraphExecutionState
6893
parser: argparse.ArgumentParser
6994
defaults: dict[str, Any]
95+
graph_nodes: dict[str, str]
96+
nodes_added: list[str]
7097

7198
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
7299
self.invoker = invoker
73100
self.session = session
74101
self.parser = parser
75102
self.defaults = dict()
103+
self.graph_nodes = dict()
104+
self.nodes_added = list()
76105

77106
def get_session(self):
78107
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
79108
return self.session
80109

110+
def reset(self):
111+
self.session = self.invoker.create_execution_state()
112+
self.graph_nodes = dict()
113+
self.nodes_added = list()
114+
# Leave defaults unchanged
115+
116+
def add_node(self, node: BaseInvocation):
117+
self.get_session()
118+
self.session.graph.add_node(node)
119+
self.nodes_added.append(node.id)
120+
self.invoker.services.graph_execution_manager.set(self.session)
121+
122+
def add_edge(self, edge: Edge):
123+
self.get_session()
124+
self.session.add_edge(edge)
125+
self.invoker.services.graph_execution_manager.set(self.session)
126+
81127

82128
class ExitCli(Exception):
83129
"""Exception to exit the CLI"""

0 commit comments

Comments
 (0)