Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions invokeai/app/cli/completer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""
Readline helper functions for cli_app.py
You may import the global singleton `completer` to get access to the
completer object.
"""
import atexit
import readline
import shlex

from pathlib import Path
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin

from ...backend import ModelManager, Globals
from ..invocations.baseinvocation import BaseInvocation
from .commands import BaseCommand

# singleton object, class variable
completer = None

class Completer(object):

def __init__(self, model_manager: ModelManager):
self.commands = self.get_commands()
self.matches = None
self.linebuffer = None
self.manager = model_manager
return

def complete(self, text, state):
"""
Complete commands and switches fromm the node CLI command line.
Switches are determined in a context-specific manner.
"""

buffer = readline.get_line_buffer()
if state == 0:
options = None
try:
current_command, current_switch = self.get_current_command(buffer)
options = self.get_command_options(current_command, current_switch)
except IndexError:
pass
options = options or list(self.parse_commands().keys())

if not text: # first time
self.matches = options
else:
self.matches = [s for s in options if s and s.startswith(text)]

try:
match = self.matches[state]
except IndexError:
match = None
return match

@classmethod
def get_commands(self)->List[object]:
"""
Return a list of all the client commands and invocations.
"""
return BaseCommand.get_commands() + BaseInvocation.get_invocations()

def get_current_command(self, buffer: str)->tuple[str, str]:
"""
Parse the readline buffer to find the most recent command and its switch.
"""
if len(buffer)==0:
return None, None
tokens = shlex.split(buffer)
command = None
switch = None
for t in tokens:
if t[0].isalpha():
if switch is None:
command = t
else:
switch = t
# don't try to autocomplete switches that are already complete
if switch and buffer.endswith(' '):
switch=None
return command or '', switch or ''

def parse_commands(self)->Dict[str, List[str]]:
"""
Return a dict in which the keys are the command name
and the values are the parameters the command takes.
"""
result = dict()
for command in self.commands:
hints = get_type_hints(command)
name = get_args(hints['type'])[0]
result.update({name:hints})
return result

def get_command_options(self, command: str, switch: str)->List[str]:
"""
Return all the parameters that can be passed to the command as
command-line switches. Returns None if the command is unrecognized.
"""
parsed_commands = self.parse_commands()
if command not in parsed_commands:
return None

# handle switches in the format "-foo=bar"
argument = None
if switch and '=' in switch:
switch, argument = switch.split('=')

parameter = switch.strip('-')
if parameter in parsed_commands[command]:
if argument is None:
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
else:
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
else:
return [f"--{x}" for x in parsed_commands[command].keys()]

def get_parameter_options(self, parameter: str, typehint)->List[str]:
"""
Given a parameter type (such as Literal), offers autocompletions.
"""
if get_origin(typehint) == Literal:
return get_args(typehint)
if parameter == 'model':
return self.manager.model_names()

def _pre_input_hook(self):
if self.linebuffer:
readline.insert_text(self.linebuffer)
readline.redisplay()
self.linebuffer = None

def set_autocompleter(model_manager: ModelManager) -> Completer:
global completer

if completer:
return completer

completer = Completer(model_manager)

readline.set_completer(completer.complete)
# pyreadline3 does not have a set_auto_history() method
try:
readline.set_auto_history(True)
except:
pass
readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(" ")
readline.parse_and_bind("tab: complete")
readline.parse_and_bind("set print-completions-horizontally off")
readline.parse_and_bind("set page-completions on")
readline.parse_and_bind("set skip-completed-text on")
readline.parse_and_bind("set show-all-if-ambiguous on")

histfile = Path(Globals.root, ".invoke_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
except OSError: # file likely corrupted
newname = f"{histfile}.old"
print(
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
)
histfile.replace(Path(newname))
atexit.register(readline.write_history_file, histfile)
11 changes: 9 additions & 2 deletions invokeai/app/cli_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
from .cli.completer import set_autocompleter
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
Expand Down Expand Up @@ -130,6 +131,12 @@ def invoke_cli():
config.parse_args()
model_manager = get_model_manager(config)

# This initializes the autocompleter and returns it.
# Currently nothing is done with the returned Completer
# object, but the object can be used to change autocompletion
# behavior on the fly, if desired.
completer = set_autocompleter(model_manager)

events = EventServiceBase()

output_folder = os.path.abspath(
Expand Down Expand Up @@ -162,8 +169,8 @@ def invoke_cli():

while True:
try:
cmd_input = input("> ")
except KeyboardInterrupt:
cmd_input = input("invoke> ")
except (KeyboardInterrupt, EOFError):
# Ctrl-c exits
break

Expand Down