Skip to content

Commit

Permalink
fix shell tool bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ruanrongman authored and Undertone0809 committed Sep 2, 2023
1 parent 7891106 commit 91a6054
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
10 changes: 8 additions & 2 deletions promptulate/client/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@
from promptulate.agents import ToolAgent
from promptulate.llms import ErnieBot, ChatOpenAI, BaseLLM
from promptulate.schema import LLMType
from promptulate.tools import Calculator, DuckDuckGoTool, PythonREPLTool, ArxivQueryTool, SleepTool
from promptulate.tools import (
Calculator,
DuckDuckGoTool,
PythonREPLTool,
ArxivQueryTool,
SleepTool,
)
from promptulate.tools.shell import ShellTool
from promptulate.utils import set_proxy_mode, print_text

Expand All @@ -39,7 +45,7 @@
"Python Script Executor": PythonREPLTool,
"Arxiv Query": ArxivQueryTool,
"Sleep": SleepTool,
"Shell Executor": ShellTool
"Shell Executor": ShellTool,
}


Expand Down
20 changes: 16 additions & 4 deletions promptulate/tools/shell/api_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import subprocess
import sys
from io import StringIO

Expand All @@ -8,10 +9,21 @@ class ShellAPIWrapper:

@staticmethod
def run(command: str) -> str:
"""
Runs a command in a subprocess and returns
the output.
Args:
command: The command to run
"""
try:
result = os.popen(command)
output = result.read()
result.close()
except Exception as e:
output = subprocess.run(
command,
shell=True,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
).stdout.decode()
except subprocess.CalledProcessError as e:
output = repr(e)
return output
2 changes: 1 addition & 1 deletion tests/tools/test_shell_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_run(self):
command = """echo hello"""
result = api_wrapper.run(command)
print(result)
self.assertEqual("hello\n", result)
self.assertEqual("hello\r\n", result)


class TestShellReplTool(TestCase):
Expand Down

0 comments on commit 91a6054

Please sign in to comment.