forked from crewAIInc/crewAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add crew train cli (crewAIInc#624)
* fix: fix crewai-tools cli command * feat: add crewai train CLI command * feat: add the tests * fix: fix typing hinting issue on code * fix: test.yml * fix: fix test * fix: removed fix since it didnt changed the test
- Loading branch information
1 parent
a336381
commit 24ed8a2
Showing
16 changed files
with
278 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,11 +6,12 @@ authors = ["Your Name <[email protected]>"] | |
|
||
[tool.poetry.dependencies] | ||
python = ">=3.10,<=3.13" | ||
crewai = {extras = ["tools"], version = "^0.30.11"} | ||
crewai = { extras = ["tools"], version = "^0.30.11" } | ||
|
||
[tool.poetry.scripts] | ||
{{folder_name}} = "{{folder_name}}.main:run" | ||
train = "{{folder_name}}.main:train" | ||
|
||
[build-system] | ||
requires = ["poetry-core"] | ||
build-backend = "poetry.core.masonry.api" | ||
build-backend = "poetry.core.masonry.api" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import subprocess | ||
|
||
import click | ||
|
||
|
||
def train_crew(n_iterations: int) -> None: | ||
""" | ||
Train the crew by running a command in the Poetry environment. | ||
Args: | ||
n_iterations (int): The number of iterations to train the crew. | ||
""" | ||
command = ["poetry", "run", "train", str(n_iterations)] | ||
|
||
try: | ||
if n_iterations <= 0: | ||
raise ValueError("The number of iterations must be a positive integer.") | ||
|
||
result = subprocess.run(command, capture_output=False, text=True, check=True) | ||
|
||
if result.stderr: | ||
click.echo(result.stderr, err=True) | ||
|
||
except subprocess.CalledProcessError as e: | ||
click.echo(f"An error occurred while training the crew: {e}", err=True) | ||
click.echo(e.output, err=True) | ||
|
||
except Exception as e: | ||
click.echo(f"An unexpected error occurred: {e}", err=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from unittest import mock | ||
|
||
import pytest | ||
from click.testing import CliRunner | ||
|
||
from crewai.cli.cli import train, version | ||
|
||
|
||
@pytest.fixture | ||
def runner(): | ||
return CliRunner() | ||
|
||
|
||
@mock.patch("crewai.cli.cli.train_crew") | ||
def test_train_default_iterations(train_crew, runner): | ||
result = runner.invoke(train) | ||
|
||
train_crew.assert_called_once_with(5) | ||
assert result.exit_code == 0 | ||
assert "Training the crew for 5 iterations" in result.output | ||
|
||
|
||
@mock.patch("crewai.cli.cli.train_crew") | ||
def test_train_custom_iterations(train_crew, runner): | ||
result = runner.invoke(train, ["--n_iterations", "10"]) | ||
|
||
train_crew.assert_called_once_with(10) | ||
assert result.exit_code == 0 | ||
assert "Training the crew for 10 iterations" in result.output | ||
|
||
|
||
@mock.patch("crewai.cli.cli.train_crew") | ||
def test_train_invalid_string_iterations(train_crew, runner): | ||
result = runner.invoke(train, ["--n_iterations", "invalid"]) | ||
|
||
train_crew.assert_not_called() | ||
assert result.exit_code == 2 | ||
assert ( | ||
"Usage: train [OPTIONS]\nTry 'train --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n" | ||
in result.output | ||
) | ||
|
||
|
||
def test_version_command(runner): | ||
result = runner.invoke(version) | ||
|
||
assert result.exit_code == 0 | ||
assert "crewai version:" in result.output | ||
|
||
|
||
def test_version_command_with_tools(runner): | ||
result = runner.invoke(version, ["--tools"]) | ||
|
||
assert result.exit_code == 0 | ||
assert "crewai version:" in result.output | ||
assert ( | ||
"crewai tools version:" in result.output | ||
or "crewai tools not installed" in result.output | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import subprocess | ||
from unittest import mock | ||
|
||
from crewai.cli.train_crew import train_crew | ||
|
||
|
||
@mock.patch("crewai.cli.train_crew.subprocess.run") | ||
def test_train_crew_positive_iterations(mock_subprocess_run): | ||
# Arrange | ||
n_iterations = 5 | ||
mock_subprocess_run.return_value = subprocess.CompletedProcess( | ||
args=["poetry", "run", "train", str(n_iterations)], | ||
returncode=0, | ||
stdout="Success", | ||
stderr="", | ||
) | ||
|
||
# Act | ||
train_crew(n_iterations) | ||
|
||
# Assert | ||
mock_subprocess_run.assert_called_once_with( | ||
["poetry", "run", "train", str(n_iterations)], | ||
capture_output=False, | ||
text=True, | ||
check=True, | ||
) | ||
|
||
|
||
@mock.patch("crewai.cli.train_crew.click") | ||
def test_train_crew_zero_iterations(click): | ||
train_crew(0) | ||
click.echo.assert_called_once_with( | ||
"An unexpected error occurred: The number of iterations must be a positive integer.", | ||
err=True, | ||
) | ||
|
||
|
||
@mock.patch("crewai.cli.train_crew.click") | ||
def test_train_crew_negative_iterations(click): | ||
train_crew(-2) | ||
click.echo.assert_called_once_with( | ||
"An unexpected error occurred: The number of iterations must be a positive integer.", | ||
err=True, | ||
) | ||
|
||
|
||
@mock.patch("crewai.cli.train_crew.click") | ||
@mock.patch("crewai.cli.train_crew.subprocess.run") | ||
def test_train_crew_called_process_error(mock_subprocess_run, click): | ||
n_iterations = 5 | ||
mock_subprocess_run.side_effect = subprocess.CalledProcessError( | ||
returncode=1, | ||
cmd=["poetry", "run", "train", str(n_iterations)], | ||
output="Error", | ||
stderr="Some error occurred", | ||
) | ||
train_crew(n_iterations) | ||
|
||
mock_subprocess_run.assert_called_once_with( | ||
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True | ||
) | ||
click.echo.assert_has_calls( | ||
[ | ||
mock.call.echo( | ||
"An error occurred while training the crew: Command '['poetry', 'run', 'train', '5']' returned non-zero exit status 1.", | ||
err=True, | ||
), | ||
mock.call.echo("Error", err=True), | ||
] | ||
) | ||
|
||
|
||
@mock.patch("crewai.cli.train_crew.click") | ||
@mock.patch("crewai.cli.train_crew.subprocess.run") | ||
def test_train_crew_unexpected_exception(mock_subprocess_run, click): | ||
# Arrange | ||
n_iterations = 5 | ||
mock_subprocess_run.side_effect = Exception("Unexpected error") | ||
train_crew(n_iterations) | ||
|
||
mock_subprocess_run.assert_called_once_with( | ||
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True | ||
) | ||
click.echo.assert_called_once_with( | ||
"An unexpected error occurred: Unexpected error", err=True | ||
) |
Oops, something went wrong.