diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 99868b6778..37bd8c62d0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,10 +23,9 @@ jobs: - name: Install Requirements run: | - sudo apt-get update && - pip install poetry && + pip install poetry poetry lock && poetry install - name: Run tests - run: poetry run pytest + run: poetry run pytest tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 428b729a86..dfa210d60f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,6 +2,8 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.4.4 hooks: - # Run the linter. - id: ruff - args: [--fix] + args: ["--fix"] + exclude: "templates" + - id: ruff-format + exclude: "templates" diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 8bec5af45e..93733419ec 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -2,6 +2,7 @@ import pkg_resources from .create_crew import create_crew +from .train_crew import train_crew @click.group() @@ -27,11 +28,25 @@ def version(tools): if tools: try: - tools_version = pkg_resources.get_distribution("crewai[tools]").version + tools_version = pkg_resources.get_distribution("crewai-tools").version click.echo(f"crewai tools version: {tools_version}") except pkg_resources.DistributionNotFound: click.echo("crewai tools not installed") +@crewai.command() +@click.option( + "-n", + "--n_iterations", + type=int, + default=5, + help="Number of iterations to train the crew", +) +def train(n_iterations: int): + """Train the crew.""" + click.echo(f"Training the crew for {n_iterations} iterations") + train_crew(n_iterations) + + if __name__ == "__main__": crewai() diff --git a/src/crewai/cli/templates/main.py b/src/crewai/cli/templates/main.py index 3aa0f35c05..469884a88f 100644 --- a/src/crewai/cli/templates/main.py +++ b/src/crewai/cli/templates/main.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +import sys from {{folder_name}}.crew import {{crew_name}}Crew @@ -7,4 +8,15 @@ def run(): inputs = { 'topic': 'AI LLMs' } - {{crew_name}}Crew().crew().kickoff(inputs=inputs) \ No newline at end of file + {{crew_name}}Crew().crew().kickoff(inputs=inputs) + + +def train(): + """ + Train the crew for a given number of iterations. + """ + try: + {{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1])) + + except Exception as e: + raise Exception(f"An error occurred while training the crew: {e}") diff --git a/src/crewai/cli/templates/pyproject.toml b/src/crewai/cli/templates/pyproject.toml index d5061ecbc9..7d898efe69 100644 --- a/src/crewai/cli/templates/pyproject.toml +++ b/src/crewai/cli/templates/pyproject.toml @@ -6,11 +6,12 @@ authors = ["Your Name "] [tool.poetry.dependencies] python = ">=3.10,<=3.13" -crewai = {extras = ["tools"], version = "^0.30.11"} +crewai = { extras = ["tools"], version = "^0.30.11" } [tool.poetry.scripts] {{folder_name}} = "{{folder_name}}.main:run" +train = "{{folder_name}}.main:train" [build-system] requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" \ No newline at end of file +build-backend = "poetry.core.masonry.api" diff --git a/src/crewai/cli/train_crew.py b/src/crewai/cli/train_crew.py index e69de29bb2..cd880db5d1 100644 --- a/src/crewai/cli/train_crew.py +++ b/src/crewai/cli/train_crew.py @@ -0,0 +1,29 @@ +import subprocess + +import click + + +def train_crew(n_iterations: int) -> None: + """ + Train the crew by running a command in the Poetry environment. + + Args: + n_iterations (int): The number of iterations to train the crew. + """ + command = ["poetry", "run", "train", str(n_iterations)] + + try: + if n_iterations <= 0: + raise ValueError("The number of iterations must be a positive integer.") + + result = subprocess.run(command, capture_output=False, text=True, check=True) + + if result.stderr: + click.echo(result.stderr, err=True) + + except subprocess.CalledProcessError as e: + click.echo(f"An error occurred while training the crew: {e}", err=True) + click.echo(e.output, err=True) + + except Exception as e: + click.echo(f"An unexpected error occurred: {e}", err=True) diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 6047435f92..dacc38e104 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -164,7 +164,9 @@ def create_crew_memory(self) -> "Crew": """Set private attributes.""" if self.memory: self._long_term_memory = LongTermMemory() - self._short_term_memory = ShortTermMemory(crew=self, embedder_config=self.embedder) + self._short_term_memory = ShortTermMemory( + crew=self, embedder_config=self.embedder + ) self._entity_memory = EntityMemory(crew=self, embedder_config=self.embedder) return self @@ -280,6 +282,10 @@ def kickoff(self, inputs: Optional[Dict[str, Any]] = {}) -> str: return result + def train(self, n_iterations: int) -> None: + # TODO: Implement training + pass + def _run_sequential_process(self) -> str: """Executes tasks sequentially and returns the final output.""" task_output = "" diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index 32e6721480..519d7a62a0 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -12,7 +12,10 @@ class EntityMemory(Memory): def __init__(self, crew=None, embedder_config=None): storage = RAGStorage( - type="entities", allow_reset=False, embedder_config=embedder_config, crew=crew + type="entities", + allow_reset=False, + embedder_config=embedder_config, + crew=crew, ) super().__init__(storage) diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index 30bf202c00..e9410ebbca 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -13,7 +13,9 @@ class ShortTermMemory(Memory): """ def __init__(self, crew=None, embedder_config=None): - storage = RAGStorage(type="short_term", embedder_config=embedder_config, crew=crew) + storage = RAGStorage( + type="short_term", embedder_config=embedder_config, crew=crew + ) super().__init__(storage) def save(self, item: ShortTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" diff --git a/src/crewai/project/crew_base.py b/src/crewai/project/crew_base.py index e58c377f72..154b094c7c 100644 --- a/src/crewai/project/crew_base.py +++ b/src/crewai/project/crew_base.py @@ -1,18 +1,18 @@ import inspect -import yaml import os - from pathlib import Path -from pydantic import ConfigDict +import yaml from dotenv import load_dotenv +from pydantic import ConfigDict + load_dotenv() def CrewBase(cls): class WrappedClass(cls): model_config = ConfigDict(arbitrary_types_allowed=True) - is_crew_class: bool = True + is_crew_class: bool = True # type: ignore base_directory = None for frame_info in inspect.stack(): diff --git a/src/crewai/task.py b/src/crewai/task.py index 1459e8ac68..715b09534f 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -306,7 +306,7 @@ def _save_file(self, result: Any) -> None: if directory and not os.path.exists(directory): os.makedirs(directory) - with open(self.output_file, "w", encoding='utf-8') as file: # type: ignore # Argument 1 to "open" has incompatible type "str | None"; expected "int | str | bytes | PathLike[str] | PathLike[bytes]" + with open(self.output_file, "w", encoding="utf-8") as file: # type: ignore # Argument 1 to "open" has incompatible type "str | None"; expected "int | str | bytes | PathLike[str] | PathLike[bytes]" file.write(result) return None diff --git a/src/crewai/tools/agent_tools.py b/src/crewai/tools/agent_tools.py index 293a1ca237..7598e9040a 100644 --- a/src/crewai/tools/agent_tools.py +++ b/src/crewai/tools/agent_tools.py @@ -33,20 +33,26 @@ def tools(self): ] return tools - def delegate_work(self, task: str, context: str, coworker: Union[str, None] = None, **kwargs): + def delegate_work( + self, task: str, context: str, coworker: Union[str, None] = None, **kwargs + ): """Useful to delegate a specific task to a co-worker passing all necessary context and names.""" coworker = coworker or kwargs.get("co_worker") or kwargs.get("co-worker") - is_list = coworker.startswith("[") and coworker.endswith("]") - if is_list: - coworker = coworker[1:-1].split(",")[0] + if coworker is not None: + is_list = coworker.startswith("[") and coworker.endswith("]") + if is_list: + coworker = coworker[1:-1].split(",")[0] return self._execute(coworker, task, context) - def ask_question(self, question: str, context: str, coworker: Union[str, None] = None, **kwargs): + def ask_question( + self, question: str, context: str, coworker: Union[str, None] = None, **kwargs + ): """Useful to ask a question, opinion or take from a co-worker passing all necessary context and names.""" coworker = coworker or kwargs.get("co_worker") or kwargs.get("co-worker") - is_list = coworker.startswith("[") and coworker.endswith("]") - if is_list: - coworker = coworker[1:-1].split(",")[0] + if coworker is not None: + is_list = coworker.startswith("[") and coworker.endswith("]") + if is_list: + coworker = coworker[1:-1].split(",")[0] return self._execute(coworker, question, context) def _execute(self, agent, task, context): diff --git a/tests/cli/cli_test.py b/tests/cli/cli_test.py new file mode 100644 index 0000000000..f2c6879b05 --- /dev/null +++ b/tests/cli/cli_test.py @@ -0,0 +1,59 @@ +from unittest import mock + +import pytest +from click.testing import CliRunner + +from crewai.cli.cli import train, version + + +@pytest.fixture +def runner(): + return CliRunner() + + +@mock.patch("crewai.cli.cli.train_crew") +def test_train_default_iterations(train_crew, runner): + result = runner.invoke(train) + + train_crew.assert_called_once_with(5) + assert result.exit_code == 0 + assert "Training the crew for 5 iterations" in result.output + + +@mock.patch("crewai.cli.cli.train_crew") +def test_train_custom_iterations(train_crew, runner): + result = runner.invoke(train, ["--n_iterations", "10"]) + + train_crew.assert_called_once_with(10) + assert result.exit_code == 0 + assert "Training the crew for 10 iterations" in result.output + + +@mock.patch("crewai.cli.cli.train_crew") +def test_train_invalid_string_iterations(train_crew, runner): + result = runner.invoke(train, ["--n_iterations", "invalid"]) + + train_crew.assert_not_called() + assert result.exit_code == 2 + assert ( + "Usage: train [OPTIONS]\nTry 'train --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n" + in result.output + ) + + +def test_version_command(runner): + result = runner.invoke(version) + + assert result.exit_code == 0 + assert "crewai version:" in result.output + + +def test_version_command_with_tools(runner): + result = runner.invoke(version, ["--tools"]) + + assert result.exit_code == 0 + assert "crewai version:" in result.output + assert ( + "crewai tools version:" in result.output + or "crewai tools not installed" in result.output + ) diff --git a/tests/cli/train_crew_test.py b/tests/cli/train_crew_test.py new file mode 100644 index 0000000000..9d0d3d4a73 --- /dev/null +++ b/tests/cli/train_crew_test.py @@ -0,0 +1,87 @@ +import subprocess +from unittest import mock + +from crewai.cli.train_crew import train_crew + + +@mock.patch("crewai.cli.train_crew.subprocess.run") +def test_train_crew_positive_iterations(mock_subprocess_run): + # Arrange + n_iterations = 5 + mock_subprocess_run.return_value = subprocess.CompletedProcess( + args=["poetry", "run", "train", str(n_iterations)], + returncode=0, + stdout="Success", + stderr="", + ) + + # Act + train_crew(n_iterations) + + # Assert + mock_subprocess_run.assert_called_once_with( + ["poetry", "run", "train", str(n_iterations)], + capture_output=False, + text=True, + check=True, + ) + + +@mock.patch("crewai.cli.train_crew.click") +def test_train_crew_zero_iterations(click): + train_crew(0) + click.echo.assert_called_once_with( + "An unexpected error occurred: The number of iterations must be a positive integer.", + err=True, + ) + + +@mock.patch("crewai.cli.train_crew.click") +def test_train_crew_negative_iterations(click): + train_crew(-2) + click.echo.assert_called_once_with( + "An unexpected error occurred: The number of iterations must be a positive integer.", + err=True, + ) + + +@mock.patch("crewai.cli.train_crew.click") +@mock.patch("crewai.cli.train_crew.subprocess.run") +def test_train_crew_called_process_error(mock_subprocess_run, click): + n_iterations = 5 + mock_subprocess_run.side_effect = subprocess.CalledProcessError( + returncode=1, + cmd=["poetry", "run", "train", str(n_iterations)], + output="Error", + stderr="Some error occurred", + ) + train_crew(n_iterations) + + mock_subprocess_run.assert_called_once_with( + ["poetry", "run", "train", "5"], capture_output=False, text=True, check=True + ) + click.echo.assert_has_calls( + [ + mock.call.echo( + "An error occurred while training the crew: Command '['poetry', 'run', 'train', '5']' returned non-zero exit status 1.", + err=True, + ), + mock.call.echo("Error", err=True), + ] + ) + + +@mock.patch("crewai.cli.train_crew.click") +@mock.patch("crewai.cli.train_crew.subprocess.run") +def test_train_crew_unexpected_exception(mock_subprocess_run, click): + # Arrange + n_iterations = 5 + mock_subprocess_run.side_effect = Exception("Unexpected error") + train_crew(n_iterations) + + mock_subprocess_run.assert_called_once_with( + ["poetry", "run", "train", "5"], capture_output=False, text=True, check=True + ) + click.echo.assert_called_once_with( + "An unexpected error occurred: Unexpected error", err=True + ) diff --git a/tests/cli_test.py b/tests/cli_test.py deleted file mode 100644 index 88247ec692..0000000000 --- a/tests/cli_test.py +++ /dev/null @@ -1,20 +0,0 @@ -from click.testing import CliRunner -from crewai.cli.cli import version - - -def test_version_command(): - runner = CliRunner() - result = runner.invoke(version) - assert result.exit_code == 0 - assert "crewai version:" in result.output - - -def test_version_command_with_tools(): - runner = CliRunner() - result = runner.invoke(version, ["--tools"]) - assert result.exit_code == 0 - assert "crewai version:" in result.output - assert ( - "crewai tools version:" in result.output - or "crewai tools not installed" in result.output - ) diff --git a/tests/crew_test.py b/tests/crew_test.py index b9923e2ea3..75868ac9e3 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -993,3 +993,35 @@ def testing_tool(first_number: int, second_number: int) -> int: with pytest.raises(Exception): crew.kickoff() + + +def test_crew_train_success(): + task = Task( + description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.", + expected_output="5 bullet points with a paragraph for each idea.", + ) + + crew = Crew( + agents=[researcher, writer], + tasks=[task], + ) + + crew.train(n_iterations=2) + + +def test_crew_train_error(): + task = Task( + description="Come up with a list of 5 interesting ideas to explore for an article", + expected_output="5 bullet points with a paragraph for each idea.", + ) + + crew = Crew( + agents=[researcher, writer], + tasks=[task], + ) + + with pytest.raises(TypeError) as e: + crew.train() + assert "train() missing 1 required positional argument: 'n_iterations'" in str( + e + )