Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Accept capitals in ask_yes_no (#443)
Browse files Browse the repository at this point in the history
Co-authored-by: PCSwingle <[email protected]>
  • Loading branch information
jakethekoenig and PCSwingle authored Jan 2, 2024
1 parent 69410b9 commit a955b58
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mentat/python_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def _call_mentat(self, message: str):
source=StreamMessageSource.CLIENT,
channel=f"input_request:{input_request_message.id}",
)
if message == "y":
if message.strip().lower() == "y":
await self.wait_for_edit_completion()

temp = self._accumulated_message
Expand Down
2 changes: 1 addition & 1 deletion mentat/session_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def ask_yes_no(default_yes: bool) -> bool:
# TODO: combine this into a single message (include content)
stream.send("(Y/n)" if default_yes else "(y/N)")
response = await collect_user_input(plain=True)
content = response.data
content = response.data.strip().lower()
if content in ["y", "n", ""]:
break
return content == "y" or (content != "n" and default_yes)
Expand Down
46 changes: 46 additions & 0 deletions tests/session_input_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import asyncio

import pytest

from mentat.session_context import SESSION_CONTEXT
from mentat.session_input import ask_yes_no


@pytest.mark.asyncio
async def test_ask_yes_no():
stream = SESSION_CONTEXT.get().stream
stream.start()

# Test when user inputs invalid confirmation
ask_yes_no_task = asyncio.create_task(ask_yes_no(default_yes=False))
input_request_message = await stream.recv("input_request")
stream.send("yes", channel=f"input_request:{input_request_message.id}")
input_request_message = await stream.recv("input_request")
stream.send("y", channel=f"input_request:{input_request_message.id}")
assert await ask_yes_no_task is True

# Test when user inputs 'Y'
ask_yes_no_task = asyncio.create_task(ask_yes_no(default_yes=False))
input_request_message = await stream.recv("input_request")
stream.send("Y", channel=f"input_request:{input_request_message.id}")
assert await ask_yes_no_task is True

# Test when user inputs nothing
ask_yes_no_task = asyncio.create_task(ask_yes_no(default_yes=False))
input_request_message = await stream.recv("input_request")
stream.send("", channel=f"input_request:{input_request_message.id}")
assert await ask_yes_no_task is False

# Test when user inputs 'n'
ask_yes_no_task = asyncio.create_task(ask_yes_no(default_yes=False))
input_request_message = await stream.recv("input_request")
stream.send("n", channel=f"input_request:{input_request_message.id}")
assert await ask_yes_no_task is False

# Test default true
ask_yes_no_task = asyncio.create_task(ask_yes_no(default_yes=True))
input_request_message = await stream.recv("input_request")
stream.send("", channel=f"input_request:{input_request_message.id}")
assert await ask_yes_no_task is True

stream.stop()

0 comments on commit a955b58

Please sign in to comment.