Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Upcoming (TBD)
==============

Features
--------
* Limit size of LLM prompts and cache LLM prompt data.


Internal
--------
* Include LLM dependencies in tox configuration.
Expand Down
3 changes: 2 additions & 1 deletion mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,9 +795,10 @@ def one_iteration(text: str | None = None) -> None:
while special.is_llm_command(text):
start = time()
try:
assert isinstance(self.sqlexecute, SQLExecute)
assert sqlexecute.conn is not None
cur = sqlexecute.conn.cursor()
context, sql, duration = special.handle_llm(text, cur)
context, sql, duration = special.handle_llm(text, cur, sqlexecute.dbname or '')
if context:
click.echo("LLM Response:")
click.echo(context)
Expand Down
74 changes: 55 additions & 19 deletions mycli/packages/special/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def cli_commands() -> list[str]:
return list(cli.commands.keys())


def handle_llm(text: str, cur: Cursor) -> tuple[str, str | None, float]:
def handle_llm(text: str, cur: Cursor, dbname: str) -> tuple[str, str | None, float]:
_, verbosity, arg = parse_special_command(text)
if not LLM_IMPORTED:
output = [(None, None, None, NEED_DEPENDENCIES)]
Expand Down Expand Up @@ -261,7 +261,7 @@ def handle_llm(text: str, cur: Cursor) -> tuple[str, str | None, float]:
try:
ensure_mycli_template()
start = time()
context, sql = sql_using_llm(cur=cur, question=arg)
context, sql = sql_using_llm(cur=cur, question=arg, dbname=dbname)
end = time()
if verbosity == Verbosity.SUCCINCT:
context = ""
Expand All @@ -275,45 +275,81 @@ def is_llm_command(command: str) -> bool:
return cmd in ("\\llm", "\\ai")


def sql_using_llm(
cur: Cursor | None,
question: str | None = None,
) -> tuple[str, str | None]:
if cur is None:
raise RuntimeError("Connect to a database and try again.")
schema_query = """
SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')')
def truncate_list_elements(row: list, target_size: int = 100000, initial_width: int = 1024) -> list:
    """Truncate str/bytes elements of *row* until the row fits a size budget.

    Non-string elements are passed through unchanged.  The truncation width
    is reduced in steps of 100 (1024, 924, ... 24 by default) until the
    summed ``sys.getsizeof`` of the truncated row is at or below
    *target_size*.

    :param row: row of column values (any types); not modified in place.
    :param target_size: byte budget for the whole truncated row.
    :param initial_width: first truncation width tried for str/bytes values.
    :return: a new list with str/bytes elements truncated to fit the budget.
    """
    width = initial_width
    # NOTE(review): stepping by 100 means width 0 is never attempted (the
    # last width tried is initial_width % 100), so the result is best-effort
    # and may still exceed target_size for rows with very many elements.
    while width >= 0:
        truncated_row = [x[:width] if isinstance(x, (str, bytes)) else x for x in row]
        if sum(sys.getsizeof(x) for x in truncated_row) <= target_size:
            break
        width -= 100
    return truncated_row


def truncate_table_lines(table: list[str]) -> list[str]:
    """Return the longest prefix of *table* that stays within a size budget.

    Lines are accumulated in order; accumulation stops once the running
    ``sys.getsizeof`` total exceeds 100000 bytes.  The line that crosses the
    budget is still included, matching the original streaming behaviour.

    Unlike the previous pop(0)-based loop, this does not mutate the caller's
    list and runs in O(n) instead of O(n^2).

    :param table: schema description lines, one table per line.
    :return: a new list holding the kept prefix of *table*.
    """
    target_size = 100000
    kept: list[str] = []
    running_sum = 0
    for line in table:
        if running_sum > target_size:
            break
        running_sum += sys.getsizeof(line)
        kept.append(line)
    return kept


@functools.cache
def get_schema(cur: Cursor, dbname: str) -> str:
    """Build a one-line-per-table schema description for *dbname*.

    Queries information_schema for every table's columns and types, then
    truncates the result to the prompt size budget via
    ``truncate_table_lines``.  Cached per (cursor, dbname) pair so repeated
    \\llm calls do not re-query the server.

    :param cur: open cursor on the MySQL connection.
    :param dbname: database (schema) name to describe.
    :return: newline-joined ``table(col type, ...)`` lines.
    """
    # NOTE(review): a fresh cursor per call defeats this cache and keeps old
    # cursors alive for the process lifetime — consider keying on dbname
    # only; confirm cursor lifetime in main.one_iteration.
    click.echo("Preparing schema information to feed the LLM")
    # dbname is bound as a query parameter instead of being interpolated
    # into the SQL text, so quotes in the name cannot break the statement.
    # SCHEMA is a reserved word in MySQL, so the alias is backtick-quoted.
    schema_query = """
        SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')') AS `schema`
        FROM information_schema.columns
        WHERE table_schema = %s
        GROUP BY table_name
        ORDER BY table_name
    """
    cur.execute(schema_query, (dbname,))
    # NOTE(review): each fetched row is unpacked with (row,) and then indexed
    # row[0] — this assumes the driver returns nested sequences; verify
    # against the cursor implementation used here.
    db_schema = [row[0] for (row,) in cur.fetchall()]
    return '\n'.join(truncate_table_lines(db_schema))


@functools.cache
def get_sample_data(cur: Cursor, dbname: str) -> dict[str, Any]:
    """Fetch one sample row from every table in *dbname* for LLM context.

    Tables that cannot be read (permissions, broken views, ...) are skipped
    silently: sample data is best-effort context only.  Values in each
    sampled row are truncated by ``truncate_list_elements`` so the prompt
    stays within its size budget.  Cached per (cursor, dbname) pair.

    :param cur: open cursor on the MySQL connection.
    :param dbname: database whose tables are sampled.
    :return: mapping of table name -> list of (column_name, value) pairs.
    """
    # NOTE(review): a fresh cursor per call defeats this cache — see the
    # matching note on get_schema.
    click.echo("Preparing sample data to feed the LLM")
    tables_query = "SHOW TABLES"
    sample_row_query = "SELECT * FROM `{dbname}`.`{table}` LIMIT 1"
    # Double any backticks inside identifiers so an unusual database or
    # table name cannot break out of the backtick quoting.
    quoted_dbname = dbname.replace('`', '``')
    cur.execute(tables_query)
    sample_data = {}
    for (table_name,) in cur.fetchall():
        quoted_table = table_name.replace('`', '``') if isinstance(table_name, str) else table_name
        try:
            cur.execute(sample_row_query.format(dbname=quoted_dbname, table=quoted_table))
        except Exception:
            # Best-effort: skip tables we cannot sample.
            continue
        cols = [desc[0] for desc in cur.description]
        row = cur.fetchone()
        if row is None:
            continue
        sample_data[table_name] = list(zip(cols, truncate_list_elements(list(row))))
    return sample_data


def sql_using_llm(
cur: Cursor | None,
question: str | None,
dbname: str = '',
) -> tuple[str, str | None]:
if cur is None:
raise RuntimeError("Connect to a database and try again.")
if dbname == '':
raise RuntimeError("Choose a schema and try again.")
args = [
"--template",
LLM_TEMPLATE_NAME,
"--param",
"db_schema",
db_schema,
get_schema(cur, dbname),
"--param",
"sample_data",
sample_data,
get_sample_data(cur, dbname),
"--param",
"question",
question,
Expand Down
22 changes: 11 additions & 11 deletions test/test_llm_special.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_llm_command_without_args(mock_llm, executor):
assert mock_llm is not None
test_text = r"\llm"
with pytest.raises(FinishIteration) as exc_info:
handle_llm(test_text, executor)
handle_llm(test_text, executor, 'mysql')
# Should return usage message when no args provided
assert exc_info.value.args[0] == [(None, None, None, USAGE)]

Expand All @@ -38,7 +38,7 @@ def test_llm_command_with_c_flag(mock_run_cmd, mock_llm, executor):
mock_run_cmd.return_value = (0, "Hello, no SQL today.")
test_text = r"\llm -c 'Something?'"
with pytest.raises(FinishIteration) as exc_info:
handle_llm(test_text, executor)
handle_llm(test_text, executor, 'mysql')
# Expect raw output when no SQL fence found
assert exc_info.value.args[0] == [(None, None, None, "Hello, no SQL today.")]

Expand All @@ -51,7 +51,7 @@ def test_llm_command_with_c_flag_and_fenced_sql(mock_run_cmd, mock_llm, executor
fenced = f"Here you go:\n```sql\n{sql_text}\n```"
mock_run_cmd.return_value = (0, fenced)
test_text = r"\llm -c 'Rewrite SQL'"
result, sql, duration = handle_llm(test_text, executor)
result, sql, duration = handle_llm(test_text, executor, 'mysql')
# Without verbose, result is empty, sql extracted
assert sql == sql_text
assert result == ""
Expand All @@ -64,7 +64,7 @@ def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor):
# 'models' is a known subcommand
test_text = r"\llm models"
with pytest.raises(FinishIteration) as exc_info:
handle_llm(test_text, executor)
handle_llm(test_text, executor, 'mysql')
mock_run_cmd.assert_called_once_with("llm", "models", restart_cli=False)
assert exc_info.value.args[0] is None

Expand All @@ -74,7 +74,7 @@ def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor):
def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor):
test_text = r"\llm --help"
with pytest.raises(FinishIteration) as exc_info:
handle_llm(test_text, executor)
handle_llm(test_text, executor, 'mysql')
mock_run_cmd.assert_called_once_with("llm", "--help", restart_cli=False)
assert exc_info.value.args[0] is None

Expand All @@ -84,7 +84,7 @@ def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor):
def test_llm_command_with_install_flag(mock_run_cmd, mock_llm, executor):
test_text = r"\llm install openai"
with pytest.raises(FinishIteration) as exc_info:
handle_llm(test_text, executor)
handle_llm(test_text, executor, 'mysql')
mock_run_cmd.assert_called_once_with("llm", "install", "openai", restart_cli=True)
assert exc_info.value.args[0] is None

Expand All @@ -98,7 +98,7 @@ def test_llm_command_with_prompt(mock_sql_using_llm, mock_ensure_template, mock_
"""
mock_sql_using_llm.return_value = ("CTX", "SELECT 1;")
test_text = r"\llm prompt 'Test?'"
context, sql, duration = handle_llm(test_text, executor)
context, sql, duration = handle_llm(test_text, executor, 'mysql')
mock_ensure_template.assert_called_once()
mock_sql_using_llm.assert_called()
assert context == "CTX"
Expand All @@ -115,7 +115,7 @@ def test_llm_command_question_with_context(mock_sql_using_llm, mock_ensure_templ
"""
mock_sql_using_llm.return_value = ("CTX2", "SELECT 2;")
test_text = r"\llm 'Top 10?'"
context, sql, duration = handle_llm(test_text, executor)
context, sql, duration = handle_llm(test_text, executor, 'mysql')
mock_ensure_template.assert_called_once()
mock_sql_using_llm.assert_called()
assert context == "CTX2"
Expand All @@ -132,7 +132,7 @@ def test_llm_command_question_verbose(mock_sql_using_llm, mock_ensure_template,
"""
mock_sql_using_llm.return_value = ("NO_CTX", "SELECT 42;")
test_text = r"\llm- 'Succinct?'"
context, sql, duration = handle_llm(test_text, executor)
context, sql, duration = handle_llm(test_text, executor, 'mysql')
assert context == ""
assert sql == "SELECT 42;"
assert isinstance(duration, float)
Expand Down Expand Up @@ -181,7 +181,7 @@ def fetchone(self):
sql_text = "SELECT 1, 'abc';"
fenced = f"Note\n```sql\n{sql_text}\n```"
mock_run_cmd.return_value = (0, fenced)
result, sql = sql_using_llm(dummy_cur, question="dummy")
result, sql = sql_using_llm(dummy_cur, question="dummy", dbname='mysql')
assert result == fenced
assert sql == sql_text

Expand All @@ -194,5 +194,5 @@ def test_handle_llm_aliases_without_args(prefix, executor, monkeypatch):

monkeypatch.setattr(llm_module, "llm", object())
with pytest.raises(FinishIteration) as exc_info:
handle_llm(prefix, executor)
handle_llm(prefix, executor, 'mysql')
assert exc_info.value.args[0] == [(None, None, None, USAGE)]