Skip to content

Commit

Permalink
Fix test case
Browse files Browse the repository at this point in the history
  • Loading branch information
kracekumar committed Feb 23, 2025
1 parent e138676 commit 2c80aef
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 25 deletions.
6 changes: 5 additions & 1 deletion litecli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,13 @@ def one_iteration(text=None):
try:
start = time()
cur = self.sqlexecute.conn and self.sqlexecute.conn.cursor()
context, sql = special.handle_llm(text, cur)
click.echo("Calling llm command")
context, sql, duration = special.handle_llm(text, cur)
if context:
click.echo("LLM Reponse:")
click.echo(context)
click.echo('---')
click.echo(f"llm command took {duration:.2f} seconds to complete the operation")
text = self.prompt_app.prompt(default=sql)
except KeyboardInterrupt:
return
Expand Down
31 changes: 11 additions & 20 deletions litecli/packages/special/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,21 +192,8 @@ def ensure_litecli_template(replace=False):
run_external_cmd("llm", PROMPT, "--save", "litecli")
return


@contextlib.contextmanager
def timer():
start = time.perf_counter()
try:
click.echo("Calling llm command")
yield # Code inside the 'with' block runs here
finally:
end = time.perf_counter()
elapsed = end - start
click.echo(f"llm command took: {elapsed:.6f} seconds")


@export
def handle_llm(text, cur) -> Tuple[str, Optional[str]]:
def handle_llm(text, cur) -> Tuple[str, Optional[str], float]:
"""This function handles the special command `\\llm`.
If it deals with a question that results in a SQL query then it will return
Expand Down Expand Up @@ -267,27 +254,31 @@ def handle_llm(text, cur) -> Tuple[str, Optional[str]]:
if not use_context:
args = parts
if capture_output:
with timer():
_, result = run_external_cmd("llm", *args, capture_output=capture_output)
start = time.perf_counter()
_, result = run_external_cmd("llm", *args, capture_output=capture_output)
end = time.perf_counter()
match = re.search(_SQL_CODE_FENCE, result, re.DOTALL)
if match:
sql = match.group(1).strip()
else:
output = [(None, None, None, result)]
raise FinishIteration(output)

return result if verbose else "", sql
return result if verbose else "", sql, end - start
else:
run_external_cmd("llm", *args, restart_cli=restart)
raise FinishIteration(None)

try:
ensure_litecli_template()
with timer():
context, sql = sql_using_llm(cur=cur, question=arg, verbose=verbose)
# Measure end to end llm command invocation.
# This measures the internal DB command to pull the schema
start = time.perf_counter()
context, sql = sql_using_llm(cur=cur, question=arg, verbose=verbose)
end = time.perf_counter()
if not verbose:
context = ""
return context, sql
return context, sql, end - start
except Exception as e:
# Something went wrong. Raise an exception and bail.
raise RuntimeError(e)
Expand Down
13 changes: 9 additions & 4 deletions tests/test_llm_special.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,15 @@ def test_llm_command_with_c_flag_and_fenced_sql(mock_run_cmd, mock_llm, executor

test_text = r"\llm -c 'Rewrite the SQL without CTE'"

result, sql = handle_llm(test_text, executor)
result, sql, duration = handle_llm(test_text, executor)

# We expect the function to return (result, sql), but result might be "" if verbose is not set
# By default, `verbose` is false unless text has something like \llm --verbose?
# The function code: return result if verbose else "", sql
# Our test_text doesn't set verbose => we expect "" for the returned context.
assert result == ""
assert sql == "SELECT * FROM table;"
assert isinstance(duration, float)


@patch("litecli.packages.special.llm.llm")
Expand Down Expand Up @@ -133,7 +134,7 @@ def test_llm_command_with_prompt(mock_sql_using_llm, mock_ensure_template, mock_
mock_sql_using_llm.return_value = ("context from LLM", "SELECT 1;")

test_text = r"\llm prompt 'Magic happening here?'"
context, sql = handle_llm(test_text, executor)
context, sql, duration = handle_llm(test_text, executor)

# ensure_litecli_template should be called
mock_ensure_template.assert_called_once()
Expand All @@ -143,6 +144,7 @@ def test_llm_command_with_prompt(mock_sql_using_llm, mock_ensure_template, mock_
mock_sql_using_llm.assert_called()
assert context == ""
assert sql == "SELECT 1;"
assert isinstance(duration, float)


@patch("litecli.packages.special.llm.llm")
Expand All @@ -155,12 +157,13 @@ def test_llm_command_question_with_context(mock_sql_using_llm, mock_ensure_templ
mock_sql_using_llm.return_value = ("You have context!", "SELECT 2;")

test_text = r"\llm 'Top 10 downloads by size.'"
context, sql = handle_llm(test_text, executor)
context, sql, duration = handle_llm(test_text, executor)

mock_ensure_template.assert_called_once()
mock_sql_using_llm.assert_called()
assert context == ""
assert sql == "SELECT 2;"
assert isinstance(duration, float)


@patch("litecli.packages.special.llm.llm")
Expand All @@ -173,7 +176,9 @@ def test_llm_command_question_verbose(mock_sql_using_llm, mock_ensure_template,
mock_sql_using_llm.return_value = ("Verbose context, oh yeah!", "SELECT 42;")

test_text = r"\llm+ 'Top 10 downloads by size.'"
context, sql = handle_llm(test_text, executor)
context, sql, duration = handle_llm(test_text, executor)

assert context == "Verbose context, oh yeah!"
assert sql == "SELECT 42;"

assert isinstance(duration, float)

0 comments on commit 2c80aef

Please sign in to comment.