Skip to content

Commit

Permalink
change run to use args and kwargs (langchain-ai#367)
Browse files Browse the repository at this point in the history
Before, `run` was not able to be called with multiple arguments. This
expands the functionality.
  • Loading branch information
agola11 authored Dec 18, 2022
1 parent a7084ad commit 8d0869c
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 10 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[flake8]
exclude =
venv
.venv
__pycache__
notebooks
Expand Down
25 changes: 16 additions & 9 deletions langchain/chains/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,23 @@ def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs) for inputs in input_list]

def run(self, text: str) -> str:
"""Run text in, text out (if applicable)."""
if len(self.input_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one input key, got {self.input_keys}."
)
def run(self, *args: str, **kwargs: str) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key, got {self.output_keys}."
f"one output key. Got {self.output_keys}."
)
return self({self.input_keys[0]: text})[self.output_keys[0]]

if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return self(args[0])[self.output_keys[0]]

if kwargs and not args:
return self(kwargs)[self.output_keys[0]]

raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
45 changes: 44 additions & 1 deletion tests/unit_tests/chains/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class FakeChain(Chain, BaseModel):

be_correct: bool = True
the_input_keys: List[str] = ["foo"]
the_output_keys: List[str] = ["bar"]

@property
def input_keys(self) -> List[str]:
Expand All @@ -21,7 +22,7 @@ def input_keys(self) -> List[str]:
@property
def output_keys(self) -> List[str]:
"""Output key of bar."""
return ["bar"]
return self.the_output_keys

def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
if self.be_correct:
Expand Down Expand Up @@ -63,3 +64,45 @@ def test_single_input_error() -> None:
chain = FakeChain(the_input_keys=["foo", "bar"])
with pytest.raises(ValueError):
chain("bar")


def test_run_single_arg() -> None:
"""Test run method with single arg."""
chain = FakeChain()
output = chain.run("bar")
assert output == "baz"


def test_run_multiple_args_error() -> None:
"""Test run method with multiple args errors as expected."""
chain = FakeChain()
with pytest.raises(ValueError):
chain.run("bar", "foo")


def test_run_kwargs() -> None:
"""Test run method with kwargs."""
chain = FakeChain(the_input_keys=["foo", "bar"])
output = chain.run(foo="bar", bar="foo")
assert output == "baz"


def test_run_kwargs_error() -> None:
"""Test run method with kwargs errors as expected."""
chain = FakeChain(the_input_keys=["foo", "bar"])
with pytest.raises(ValueError):
chain.run(foo="bar", baz="foo")


def test_run_args_and_kwargs_error() -> None:
"""Test run method with args and kwargs."""
chain = FakeChain(the_input_keys=["foo", "bar"])
with pytest.raises(ValueError):
chain.run("bar", foo="bar")


def test_multiple_output_keys_error() -> None:
"""Test run with multiple output keys errors as expected."""
chain = FakeChain(the_output_keys=["foo", "bar"])
with pytest.raises(ValueError):
chain.run("bar")

0 comments on commit 8d0869c

Please sign in to comment.