From 4b4e3b51d23c41c311b88f9e968cab126e26bd22 Mon Sep 17 00:00:00 2001 From: Edoardo Baldi Date: Wed, 20 Nov 2024 15:19:44 +0100 Subject: [PATCH 1/2] Create docker-build.yml --- .github/workflows/docker-build.yml | 36 ++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 .github/workflows/docker-build.yml diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml new file mode 100644 index 00000000..e9e7e7b9 --- /dev/null +++ b/.github/workflows/docker-build.yml @@ -0,0 +1,36 @@ +name: Build Tutorial Container + +on: + push: + branches: + - main + paths-ignore: + - '*.md' + - slides/** + - images/** + - .gitignore + workflow_dispatch: + +jobs: + build-and-push: + runs-on: ubuntu-latest + permissions: + packages: write + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Log in to GHCR + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build the Docker image + run: | + docker build -t ghcr.io/${{ github.repository }}:latest . + + - name: Push the Docker image + run: | + docker push ghcr.io/${{ github.repository }}:latest From 2af358d75b5d2e0fddbb9ed8c494fb0d5ba9fbc7 Mon Sep 17 00:00:00 2001 From: Edoardo Baldi Date: Mon, 25 Nov 2024 21:48:01 +0100 Subject: [PATCH 2/2] Integrate AI feedback for tests results (#248) (#3) --- .gitignore | 1 + binder/environment.yml | 4 + intro.ipynb | 4 +- openai.env.example | 3 + tutorial/tests/testsuite/ai_helpers.py | 676 +++++++++++++++++++++ tutorial/tests/testsuite/exceptions.py | 46 ++ tutorial/tests/testsuite/helpers.py | 777 +++++++++++++++++++------ tutorial/tests/testsuite/testsuite.py | 177 ++++-- 8 files changed, 1448 insertions(+), 240 deletions(-) create mode 100644 openai.env.example create mode 100644 tutorial/tests/testsuite/ai_helpers.py diff --git a/.gitignore b/.gitignore index bbdff5c2..5d071089 100644 --- a/.gitignore +++ b/.gitignore @@ -109,6 +109,7 @@ venv/ ENV/ env.bak/ venv.bak/ +openai.env # Spyder project settings .spyderproject diff --git a/binder/environment.yml b/binder/environment.yml index 7d382e1d..60d0322a 100644 --- a/binder/environment.yml +++ b/binder/environment.yml @@ -21,3 +21,7 @@ dependencies: - scikit-learn - attrs - multiprocess + - openai + - tenacity + - markdown2 + - python-dotenv diff --git a/intro.ipynb b/intro.ipynb index 2b9762d6..3d71321c 100644 --- a/intro.ipynb +++ b/intro.ipynb @@ -372,7 +372,7 @@ "%%ipytest\n", "\n", "def solution_sum_two_numbers(a: int, b: int) -> int:\n", - " wrong_sum = a - b\n", + " wrong_sum = a - b -\n", " return wrong_sum" ] }, @@ -469,7 +469,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.10.15" }, "vscode": { "interpreter": { diff --git a/openai.env.example b/openai.env.example new file mode 100644 index 00000000..618a6c05 --- /dev/null +++ b/openai.env.example @@ -0,0 +1,3 @@ +OPENAI_API_KEY="sk-**********" # your OpenAI API key +OPENAI_MODEL="gpt-4o-mini" # the model you want to use +OPENAI_LANGUAGE="English" # the language you want to use diff --git a/tutorial/tests/testsuite/ai_helpers.py b/tutorial/tests/testsuite/ai_helpers.py new file mode 100644 index 00000000..fce40504 --- /dev/null +++ b/tutorial/tests/testsuite/ai_helpers.py @@ -0,0 +1,676 @@ +import logging +import traceback +import typing as t +from enum import Enum +from threading import Timer + +import ipywidgets as widgets +import markdown2 as md +import openai +from 
IPython.display import Code, display, display_html +from openai.types.chat import ( + ChatCompletionMessage, + ChatCompletionMessageParam, + ParsedChatCompletionMessage, +) +from pydantic import BaseModel +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_fixed, + wait_random, +) + +from .exceptions import ( + APIConnectionError, + InvalidAPIKeyError, + InvalidModelError, + UnexpectedAPIError, + ValidationResult, +) + +if t.TYPE_CHECKING: + from .helpers import IPytestResult + +# Set logger +logger = logging.getLogger() + + +class ExplanationStep(BaseModel): + """A single step in the explanation""" + + title: t.Optional[str] + content: str + + +class CodeSnippet(BaseModel): + """A code snippet with optional description""" + + code: str + description: t.Optional[str] + + +class Explanation(BaseModel): + """A structured explanation with steps, code snippets, and hints""" + + summary: str + steps: t.List[ExplanationStep] + code_snippets: t.List[CodeSnippet] + hints: t.List[str] + + +class OpenAIWrapper: + """A simple API wrapper adapted for IPython environments""" + + # These are the models we can use: they must support structured responses + GPT_STABLE_MODELS = ("gpt-4o", "gpt-4o-mini") + GPT_ALL_MODELS = GPT_STABLE_MODELS + + DEFAULT_MODEL = "gpt-4o-mini" + DEFAULT_LANGUAGE = "English" + + _instance = None + + def __new__(cls, *args, **kwargs) -> "OpenAIWrapper": + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + @classmethod + def create_validated( + cls, + api_key: str, + model: t.Optional[str] = None, + language: t.Optional[str] = None, + ) -> t.Tuple["OpenAIWrapper", ValidationResult]: + instance = cls.__new__(cls) + + # Only initialize if not already initialized + if not hasattr(instance, "client"): + instance.api_key = api_key + instance.language = language or cls.DEFAULT_LANGUAGE + instance.model = model or cls.DEFAULT_MODEL + instance.client = openai.OpenAI(api_key=api_key) + + # Validate the model + model_validation = instance.validate_model(instance.model) + return instance, model_validation + + @classmethod + def validate_api_key(cls, api_key: t.Optional[str]) -> ValidationResult: + """Validate the OpenAI API key""" + if not api_key: + return ValidationResult( + is_valid=False, + error=InvalidAPIKeyError("API key is missing."), + message="OpenAI API key is not provided.", + ) + + try: + client = openai.OpenAI(api_key=api_key) + client.models.list() # the simplest API call to verify the key + except openai.AuthenticationError: + return ValidationResult( + is_valid=False, + error=InvalidAPIKeyError("The provided API key is invalid."), + message="Invalid OpenAI API key. Please double-check it.", + ) + except openai.APIConnectionError: + return ValidationResult( + is_valid=False, + error=APIConnectionError("Unable to connect to OpenAI."), + message="Could not connect to OpenAI. Please check your internet connection.", + ) + except Exception as e: + return ValidationResult( + is_valid=False, + error=UnexpectedAPIError(f"Unexpected error: {e}"), + message="An unexpected error occurred while validating the API key.", + ) + else: + return ValidationResult(is_valid=True) + + def validate_model(self, model: t.Optional[str]) -> ValidationResult: + """Validate the model selection""" + try: + if model not in self.GPT_ALL_MODELS: + return ValidationResult( + is_valid=False, + error=InvalidModelError(), + message=f"Invalid model: {model}. 
Available models: {' '.join(self.GPT_ALL_MODELS)}", + ) + except Exception as e: + return ValidationResult( + is_valid=False, + error=UnexpectedAPIError(f"Error validating model: {e}"), + message="Unexpected error during model validation", + ) + + return ValidationResult(is_valid=True) + + def __init__( + self, + api_key: t.Optional[str], + model: t.Optional[str] = None, + language: t.Optional[str] = None, + ) -> None: + """Initialize the wrapper for OpenAI API with logging and checks""" + # Avoid reinitializing the client + if hasattr(self, "client"): + return + + # Validate the API key + validation = self.validate_api_key(api_key) + if not validation.is_valid: + assert validation.error is not None # for type checking + raise validation.error + + self.api_key = api_key + self.language = language or self.DEFAULT_LANGUAGE + self.client = openai.OpenAI(api_key=self.api_key) + + self.model = model or self.DEFAULT_MODEL + model_validation = self.validate_model(self.model) + if not model_validation.is_valid: + assert model_validation.error is not None # type checking + raise model_validation.error + + def change_model(self, model: str) -> None: + """Change the active OpenAI model in use""" + validation = self.validate_model(model) + if not validation.is_valid: + assert validation.error is not None # type checking + logger.exception("Error changing model") + raise validation.error + + self.model = model + logger.info("Model changed to %s", self.model) + + @retry( + retry=retry_if_exception_type(openai.RateLimitError), + stop=stop_after_attempt(3), + wait=wait_fixed(10) + wait_random(0, 5), + ) + def get_chat_response( + self, query: str, *args, **kwargs + ) -> ParsedChatCompletionMessage | ChatCompletionMessage: + """Fetch a completion from the chat model""" + system_prompt = ( + "As an expert Python developer, provide clear and concise explanations of error tracebacks, " + "focusing on the root cause for users with minimal Python experience. " + "Follow these guidelines strictly:\n" + "- Offer hints, even for trivial errors.\n" + "- Take into account the number of attempts made by providing increasingly detailed hints after a failed attempt.\n" + "- Do not provide verbatim solutions, only guidance.\n" + f"- Respond in {self.language}.\n" + "- Any text or string must be written in Markdown." 
+ ) + + messages: t.List[ChatCompletionMessageParam] = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": query}, + ] + + try: + response = self.client.beta.chat.completions.parse( + model=self.model, + messages=messages, + response_format=Explanation, + **kwargs, + ) + except openai.APIError: + logger.exception("API error encountered.") + raise + except openai.LengthFinishReasonError: + logger.exception("Input prompt has too many tokens.") + raise + else: + return response.choices[0].message + + +class ButtonState(Enum): + """The state of the explanation button""" + + READY = "ready" + LOADING = "loading" + WAIT = "waiting" + + +class AIExplanation: + """Class representing an AI-generated explanation""" + + _STYLES = """ + + """ + + def __init__( + self, + ipytest_result: "IPytestResult", + openai_client: "OpenAIWrapper", + exception: t.Optional[BaseException] = None, + wait_time: int = 60, # Wait time in seconds + ) -> None: + """Public constructor for an explanation widget""" + self.ipytest_result = ipytest_result + self.exception = exception + self.openai_client = openai_client + + # The output widget for displaying the explanation + self._output = widgets.Output() + + # Timer and state + self._timer: t.Optional[Timer] = None + self._is_throttled = False + self._wait_time = float(wait_time) + self._remaining_time = float(wait_time) + + # The button widget for fetching the explanation + self._button_styles = { + ButtonState.READY: { + "description": "Get AI Explanation", + "icon": "search", + "disabled": False, + }, + ButtonState.LOADING: { + "description": "Loading...", + "icon": "spinner", + "disabled": True, + }, + ButtonState.WAIT: { + "description": "Please wait", + "icon": "hourglass-start", + "disabled": True, + }, + } + + # Set a default query + self._query_template = ( + "I wrote the following Python function:\n\n" + "{function_code}\n\n" + "Whose docstring describes the purpose, arguments, and expected return values:\n\n" + "{docstring}\n\n" + "Running pytest on this function failed, and here is the error traceback I got:\n\n" + "{traceback}\n\n" + "Consider that this is my {attempt_number} attempt." + ) + + # Default query params + self._query_params = { + "function_code": "", + "docstring": "", + "traceback": "", + "attempt_number": self.ipytest_result.test_attempts, + } + + # Create a header with button and timer + self._header = widgets.Box( + layout=widgets.Layout( + display="flex", + align_items="center", + margin="0 0 1rem 0", + ) + ) + + # Create a timer display + self._timer_display = widgets.HTML( + value="", layout=widgets.Layout(margin="0 0 0 0.75rem") + ) + + # Initialize the button + self._current_state = ButtonState.READY + self._button = widgets.Button() + self._update_button_state(ButtonState.READY) + self._button.on_click(self._handle_click) + self._button.observe(self._handle_state_change, names=["disabled"]) + + def render(self) -> widgets.Widget: + """Return a single widget containing all the components""" + style_html = widgets.HTML(self._STYLES) + + header_html = widgets.HTML( + '
<div style="display: flex; align-items: center; font-weight: 600;">' + '🤖 Explain With AI' + "</div>
" + ) + + button_container = widgets.Box( + [ + self._button, + self._timer_display, + ], + layout=widgets.Layout(display="flex", align_iterms="center"), + ) + + # Create the rendered container + container = widgets.VBox( + children=[ + style_html, + header_html, + button_container, + self._output, + ], + layout=widgets.Layout(margin="1rem 0 0 0", padding="0"), + ) + + return container + + def set_query_template(self, template: str) -> None: + """Set the query template""" + self._query_template = template + + def query_params(self, *args, **kwargs: t.Any) -> None: + """Add/update multiple query parameters""" + self._query_params.update(kwargs) + + @property + def query(self) -> str: + """Generate the query string""" + logger.debug("Building a query with parameters: %s", self._query_params) + try: + return self._query_template.format(**self._query_params) + except KeyError as e: + logger.exception("Missing key in query parameter") + raise ValueError from e + + def _update_remaining_time(self): + """Update the button label with remaining time""" + self._remaining_time = max(0, self._remaining_time - 1) + if self._is_throttled: + self._update_button_state(ButtonState.WAIT) + + def _update_button_state(self, state: ButtonState) -> None: + """Update the button state""" + self._current_state = state + style = self._button_styles[state].copy() + + # Update the timer display + if state == ButtonState.WAIT: + self._timer_display.value = ( + 'Available in ' + f"{int(self._remaining_time)} " + "seconds" + ) + else: + self._timer_display.value = "" + + self._button.add_class("ai-button") + self._button.description = style["description"] + self._button.icon = style["icon"] + self._button.disabled = style["disabled"] + + def _handle_state_change(self, change): + """Handle the state change of the button""" + if change["new"]: + if self._is_throttled: + self._update_button_state(ButtonState.WAIT) + else: + self._is_throttled = False + self._update_button_state(ButtonState.READY) + + def _enable_button(self): + """Enable the button after a delay""" + self._button.disabled = False + self._timer = None + self._remaining_time = self._wait_time + + def _handle_click(self, _) -> None: + """Handle the button click event with throttling""" + if self._is_throttled: + self._update_button_state(ButtonState.WAIT) + return + + self._is_throttled = True + self._button.disabled = True + self._remaining_time = self._wait_time + + def update_timer(): + if self._remaining_time > 0: + self._update_remaining_time() + self._timer = Timer(1, update_timer) + self._timer.start() + else: + self._enable_button() + + self._timer = Timer(1.0, update_timer) + self._timer.start() + + # Call the method to fetch the explanation + self._fetch_explanation() + + def _fetch_explanation(self) -> None: + """Fetch the explanation from OpenAI API""" + from .helpers import IPytestOutcome + + logger.debug("Attempting to fetch explanation from OpenAI API.") + + if not self.openai_client: + return + + self._update_button_state(ButtonState.LOADING) + + if self.exception: + traceback_str = "".join(traceback.format_exception_only(self.exception)) + logger.debug("Formatted traceback: %s", traceback_str) + else: + traceback_str = "No traceback available." 
+ + with self._output: + self._output.clear_output() + + try: + # assert self.ipytest_result.function is not None + match self.ipytest_result.status: + case IPytestOutcome.FINISHED if self.ipytest_result.function is not None: + self.query_params( + function_code=self.ipytest_result.function.source_code, + docstring=self.ipytest_result.function.implementation.__doc__, + traceback=traceback_str, + ) + case _: + self.query_params( + function_code=self.ipytest_result.cell_content, + docstring="(Find it in the function's definition above.)", + traceback=traceback_str, + ) + + response = self.openai_client.get_chat_response( + self.query, + temperature=0.2, + ) + + logger.debug("Received response: %s", response) + + formatted_response = self._format_explanation(response) + + logger.debug("Formatted response: %s", formatted_response) + + if formatted_response: + display(widgets.VBox(children=formatted_response)) + else: + display(widgets.HTML("

<div style='color: #6b7280;'>No explanation could be generated.</div>
")) + except Exception as e: + logger.exception("An error occurred while fetching the explanation.") + display_html(f"

<div style='color: #b91c1c;'>Failed to fetch explanation: {e}</div>
", raw=True) + finally: + if self._is_throttled: + self._update_button_state(ButtonState.WAIT) + else: + self._update_button_state(ButtonState.READY) + + def _format_explanation( + self, chat_response: ParsedChatCompletionMessage | ChatCompletionMessage + ) -> t.Optional[t.List[t.Any]]: + """Format the explanation response for display""" + + # Initialize the Markdown to HTML converter + def to_html(text: t.Any) -> str: + """Markdown to HTML converter""" + return md.markdown(str(text)) + + # Reset the explanation object + explanation = None + + # A list to store all the widgets + widgets_list = [] + + if ( + isinstance(chat_response, ParsedChatCompletionMessage) + and (explanation := chat_response.parsed) is not None + ): + logger.debug("Response is a valid `Explanation` object that can be parsed.") + + # A summary of the explanation + summary_widget = widgets.HTML(f"

<div class='explanation-summary'>{to_html(explanation.summary)}</div>
") + widgets_list.append(summary_widget) + + # Add steps as Accordion widgets + steps_widgets = [] + for i, step in enumerate(explanation.steps, start=1): + step_title = step.title or f"Step {i}" + step_content = widgets.HTML(to_html(step.content)) + step_accordion = widgets.Accordion( + children=[step_content], titles=(step_title,) + ) + steps_widgets.append(step_accordion) + + widgets_list.extend(steps_widgets) + + # Add code snippets using Code widgets + if explanation.code_snippets: + for i, snippet in enumerate(explanation.code_snippets, start=1): + snippet_output = widgets.Output() + snippet_description = widgets.HTML(to_html(snippet.description)) + snippet_output.append_display_data(snippet_description) + + snippet_code = Code(language="python", data=snippet.code) + snippet_output.append_display_data(snippet_code) + + snippet_accordion = widgets.Accordion( + children=[snippet_output], titles=(f"Code Snippet #{i}",) + ) + + widgets_list.append(snippet_accordion) + + # Add hints as bullet points + if explanation.hints: + hints_html = ( + "" + ) + hints_widget = widgets.Accordion( + children=[widgets.HTML(hints_html)], + titles=("Hints",), + ) + widgets_list.append(hints_widget) + + elif ( + isinstance(chat_response, ChatCompletionMessage) + and (explanation := chat_response.content) is not None + ): + logger.debug( + "Response is not a structured `Explanation` object, returning as-is." + ) + explanation = ( + explanation.removeprefix("```html").removesuffix("```").strip() + ) + + widgets_list.append(widgets.HTML(to_html(explanation))) + + if explanation is not None: + # Wrap everything in a styled container + container = widgets.VBox( + children=[widgets.HTML('
<div class="ai-explanation">')] + widgets_list + [widgets.HTML("</div>
")] + ) + return [container] + else: + logger.debug("Failed to parse explanation.") + + return None diff --git a/tutorial/tests/testsuite/exceptions.py b/tutorial/tests/testsuite/exceptions.py index 11e2bfc1..a32769bf 100644 --- a/tutorial/tests/testsuite/exceptions.py +++ b/tutorial/tests/testsuite/exceptions.py @@ -1,3 +1,6 @@ +from dataclasses import dataclass + + class FunctionNotFoundError(Exception): """Custom exception raised when the solution code cannot be parsed""" @@ -17,3 +20,46 @@ class TestModuleNotFoundError(Exception): def __init__(self) -> None: super().__init__("Test module is not defined") + + +class PytestInternalError(Exception): + """Custom exception raised when the test module cannot be found""" + + def __init__(self) -> None: + super().__init__("Pytest internal error") + + +class OpenAIWrapperError(Exception): + """Base exception for OpenAI validation errors""" + + +class InvalidAPIKeyError(OpenAIWrapperError): + """Invalid API key""" + + +class APIConnectionError(OpenAIWrapperError): + """Connection error""" + + +class UnexpectedAPIError(OpenAIWrapperError): + """Unexpected API error""" + + +class InvalidModelError(OpenAIWrapperError): + """Invalid model selection""" + + +@dataclass +class ValidationResult: + """Result of OpenAI wrapper validation""" + + is_valid: bool + error: OpenAIWrapperError | None = None + message: str = "" + + @property + def user_message(self) -> str: + """Get a user-friendly message""" + if self.error is not None: + return f"🚫 {self.message}
{str(self.error)}" + return "✅ OpenAI client configured successfully." diff --git a/tutorial/tests/testsuite/helpers.py b/tutorial/tests/testsuite/helpers.py index aab40e4e..21607bd2 100644 --- a/tutorial/tests/testsuite/helpers.py +++ b/tutorial/tests/testsuite/helpers.py @@ -1,10 +1,9 @@ import html -import re -import traceback from dataclasses import dataclass from enum import Enum +from pathlib import Path from types import TracebackType -from typing import Callable, ClassVar, Dict, List, Optional +from typing import Any, Callable, ClassVar, Dict, List, Optional import ipywidgets import pytest @@ -12,6 +11,8 @@ from IPython.display import display as ipython_display from ipywidgets import HTML +from .ai_helpers import AIExplanation, OpenAIWrapper + class TestOutcome(Enum): PASS = 1 @@ -28,77 +29,341 @@ class IPytestOutcome(Enum): UNKNOWN_ERROR = 5 +@dataclass +class DebugOutput: + """Class to format debug information about test execution""" + + module_name: str + module_file: Path + results: List["IPytestResult"] + + def to_html(self) -> str: + """Format debug information as HTML""" + debug_parts = [ + """ + +
+ """ + ] + + # Overall test run info + debug_parts.append('
Debug Information
') + debug_parts.append( + '
' + f"Module: {self.module_name}
" + f"Module file: {self.module_file}
" + f"Number of results: {len(self.results)}" + "
" + ) + + # Detailed results + for i, result in enumerate(self.results, 1): + debug_parts.append( + f'
' + f'Result #{i}
' + f'Status: {result.status.name if result.status else "None"}
' + f'Function: {result.function.name if result.function else "None"}
' + f'Solution attempts: {result.test_attempts}' + ) + + if result.test_results: + debug_parts.append( + f'
' + f"Test Results ({len(result.test_results)}):" + '
' + ) + for test in result.test_results: + debug_parts.append( + f'• {test.test_name}: {test.outcome.name}' + f'{f" - {type(test.exception).__name__}: {str(test.exception)}" if test.exception else ""}
' + ) + debug_parts.append("
") + + if result.exceptions: + debug_parts.append( + f'
' + f"Exceptions ({len(result.exceptions)}):" + '
' + ) + for exc in result.exceptions: + debug_parts.append(f"• {type(exc).__name__}: {str(exc)}
") + debug_parts.append("
") + + debug_parts.append("
") + + debug_parts.append("
") + + return "\n".join(debug_parts) + + @dataclass class TestCaseResult: """Container class to store the test results when we collect them""" test_name: str outcome: TestOutcome - exception: BaseException | None - traceback: TracebackType | None + exception: Optional[BaseException] = None + traceback: Optional[TracebackType] = None + formatted_exception: str = "" stdout: str = "" stderr: str = "" + report_output: str = "" + + def __str__(self) -> str: + """Basic string representation""" + return ( + f"TestCaseResult(\n" + f" test_name: {self.test_name}\n" + f" outcome: {self.outcome.name if self.outcome else 'None'}\n" + f" exception: {type(self.exception).__name__ if self.exception else 'None'}" + f" - {str(self.exception) if self.exception else ''}\n" + f" formatted_exception: {self.formatted_exception[:100]}..." + f" ({len(self.formatted_exception)} chars)\n" + f" stdout: {len(self.stdout)} chars\n" + f" stderr: {len(self.stderr)} chars\n" + f" report_output: {len(self.report_output)} chars\n" + ")" + ) + def to_html(self) -> str: + """HTML representation of the test result""" + # CSS styles for the output + styles = """ + + """ + + # Determine test status and icon + match self.outcome: + case TestOutcome.PASS: + status_class = "test-pass" + icon = "✅" + status_text = "Passed" + case TestOutcome.FAIL: + status_class = "test-fail" + icon = "❌" + status_text = "Failed" + case TestOutcome.TEST_ERROR: + status_class = "test-error" + icon = "🚨" + status_text = "Syntax Error" + case _: + status_class = "test-error" + icon = "⚠️" + status_text = "Error" + + # Start building the HTML content + test_name = self.test_name.split("::")[-1] + html_parts = [styles] + + # Main container + html_parts.append( + f""" +
+ <div class="test-result {status_class}">
+ <div class="test-header">
+ {icon}
+ {f'<span class="test-name">{html.escape(test_name)}</span>' if test_name else ''}
+ {html.escape(status_text)}
+ </div>
+ """ + ) -@dataclass -class IPytestResult: - function_name: Optional[str] = None - status: Optional[IPytestOutcome] = None - test_results: Optional[List[TestCaseResult]] = None - exceptions: Optional[List[BaseException]] = None - test_attempts: int = 0 + # Exception information if test failed + if self.exception is not None: + exception_type = type(self.exception).__name__ + exception_message = str(self.exception) + + html_parts.append( + f""" +
+ <div class="test-exception">
+ <span class="exception-type">{html.escape(exception_type)}</span>
+ <pre class="exception-message">{html.escape(exception_message)}</pre>
+ </div>
+ """ + ) + # Output sections (if any) + if self.stdout or self.stderr: + # Generate unique IDs for this test's tabs + tab_id = f"test_{hash(self.test_name)}" + html_parts.append( + f""" +
+ <div class="test-output-tabs">
+ <input type="radio" id="{tab_id}_stdout" name="{tab_id}" checked/>
+ <label for="{tab_id}_stdout">stdout</label>
+ <input type="radio" id="{tab_id}_stderr" name="{tab_id}"/>
+ <label for="{tab_id}_stderr">stderr</label>
+ <div class="tab-content">
+ <div class="tab-panel stdout-panel">
+ <pre>{html.escape(self.stdout) if self.stdout else 'No output'}</pre>
+ </div>
+ <div class="tab-panel stderr-panel">
+ <pre>{html.escape(self.stderr) if self.stderr else 'No errors'}</pre>
+ </div>
+ </div>
+ </div>
+ """ + ) -def format_error(exception: BaseException) -> str: - """ - Takes the output of traceback.format_exception_only() for an AssertionError - and returns a formatted string with clear, structured information. - """ - formatted_message = None + # Close main div + html_parts.append("
") - # Get a string representation of the exception, without the traceback - exception_str = "".join(traceback.format_exception_only(exception)) + return "\n".join(html_parts) - # Handle the case where we were expecting an exception but none was raised - if "DID NOT RAISE" in exception_str: - pattern = r"" - match = re.search(pattern, exception_str) - if match: - formatted_message = ( - "

<h4>Expected exception:</h4>
" - f"<p>Exception <code>{html.escape(match.group(1))}</code> was not raised.</p>
" - ) - else: - # Regex pattern to extract relevant parts of the assertion message - pattern = ( - r"(\w+): assert (.*?) == (.*?)\n \+ where .*? = (.*?)\n \+ and .*? = (.*)" - ) - match = re.search(pattern, exception_str) - - if match: - ( - assertion_type, - actual_value, - expected_value, - actual_expression, - expected_expression, - ) = (html.escape(m) for m in match.groups()) - - # Formatting the output as HTML - formatted_message = ( - f"<h4>{assertion_type}:</h4>
" - "<ul style='list-style-type: none;'>
" - f"<li>• Failed Assertion: {actual_value} == {expected_value}</li>
" - f"<li>• Actual Value: {actual_value} obtained from {actual_expression}</li>
" - f"<li>• Expected Value: {expected_value} obtained from {expected_expression}</li>
" - "</ul>
" - ) - - # If we couldn't parse the exception message, just display it as is - formatted_message = formatted_message or f"<pre>{exception_str}</pre>
" + name: str + implementation: Callable[..., Any] + source_code: Optional[str] - return formatted_message + +@dataclass +class IPytestResult: + """Class to store the results of running pytest on a solution function""" + + function: Optional[AFunction] = None + status: Optional[IPytestOutcome] = None + test_results: Optional[List[TestCaseResult]] = None + exceptions: Optional[List[BaseException]] = None + test_attempts: int = 0 + cell_content: Optional[str] = None @dataclass @@ -108,6 +373,7 @@ class TestResultOutput: ipytest_result: IPytestResult solution: Optional[str] = None MAX_ATTEMPTS: ClassVar[int] = 3 + openai_client: Optional[OpenAIWrapper] = None def display_results(self) -> None: """Display the test results in an output widget as a VBox""" @@ -128,62 +394,223 @@ def display_results(self) -> None: else False ) - if success or self.ipytest_result.test_attempts > 2: + if success or self.ipytest_result.test_attempts >= self.MAX_ATTEMPTS: cells.append(solution_cell) else: if tests_finished: - cells.append( - HTML( - "

<div class='solution-notice'>📝 A proposed solution will appear after " - f"{TestResultOutput.MAX_ATTEMPTS - self.ipytest_result.test_attempts} " - f"more failed attempt{'s' if self.ipytest_result.test_attempts < 2 else ''}.</div>
", - ) + attempts_remaining = ( + self.MAX_ATTEMPTS - self.ipytest_result.test_attempts ) - else: cells.append( HTML( - "<div class='compile-error-notice'>
⚠️ Your code could not run because of an error. Please, double-check it.</div>
" + '<div class="solution-notice">' + f'<div style="display: flex; align-items: center; gap: 0.5rem;">' + '📝' + 'Solution will be available after ' + f'{attempts_remaining} more failed attempt{"s" if attempts_remaining > 1 else ""}' + "</div>" + "</div>
" ) ) ipython_display( ipywidgets.VBox( children=cells, - # CSS: "border: 1px solid; border-color: lightgray; background-color: #FAFAFA; margin: 5px; padding: 10px;" layout={ - "border": "1px solid lightgray", - "background-color": "#FAFAFA", + "border": "1px solid #e5e7eb", + "background-color": "#ffffff", "margin": "5px", - "padding": "10px", + "padding": "0.75rem", + "border-radius": "0.5rem", }, ) ) - def prepare_solution_cell(self) -> ipywidgets.Widget: - """Prepare the cell to display the solution code""" - solution_code = ipywidgets.Output() + # TODO: This is left for reference if we ever want to bring back this styling + # Perhaps we should remove it if it's unnecessary + def __prepare_solution_cell(self) -> ipywidgets.Widget: + """Prepare the cell to display the solution code with a redacted effect until revealed""" + # Generate a unique ID for each solution cell + uuid = f"solution_{id(self)}" + + styles = """ + + """ + solution_cell = ipywidgets.Output() - solution_cell.append_display_data(HTML("

<div class='solution-header'>👉 Proposed solution:</div>
")) + # Return an empty output widget if no solution is provided + if self.solution is None: + return solution_cell - solution_code.append_display_data( - Code(language="python", data=f"{self.solution}") + # Solution cell with redacted effect + solution_cell.append_display_data( + HTML( + f""" + {styles} +
+ <div class="solution-container">
+ <div class="solution-header">
+ 👉 Proposed solution
+ </div>
+ <div class="solution-body">
+ <div class="solution-code redacted" id="{uuid}">
+ {Code(data=self.solution, language="python")._repr_html_()}
+ </div>
+ <button class="reveal-button" onclick="document.getElementById('{uuid}').classList.remove('redacted')">
+ </button>
+ </div>
+ </div>
+ """ + ) ) - solution_accordion = ipywidgets.Accordion( - titles=("Click here to reveal",), children=[solution_code] + return solution_cell + + def prepare_solution_cell(self) -> ipywidgets.Widget: + """Prepare the cell to display the solution code with a collapsible accordion""" + # Return an empty output widget if no solution is provided + if self.solution is None: + return ipywidgets.Output() + + # Create the solution content + solution_output = ipywidgets.Output( + layout=ipywidgets.Layout(padding="1rem", border="1px solid #e5e7eb") ) + with solution_output: + ipython_display(Code(data=self.solution, language="python")) + + # Create header with emoji + header_output = ipywidgets.Output() + with header_output: + ipython_display( + HTML( + '
<div style="display: flex; align-items: center; gap: 0.25rem;">' + '👉' + 'Proposed solution' + "</div>
" + ) + ) - solution_cell.append_display_data(ipywidgets.Box(children=[solution_accordion])) + # Create the collapsible accordion (closed by default) + accordion = ipywidgets.Accordion( + children=[solution_output], + selected_index=None, # Start collapsed + titles=("View solution",), + layout=ipywidgets.Layout( + margin="1.5rem 0 0 0", + border="1px solid #e5e7eb", + border_radius="0.5rem", + ), + ) - return solution_cell + return ipywidgets.VBox( + children=[header_output, accordion], + layout=ipywidgets.Layout( + margin="0", + padding="0", + ), + ) def prepare_output_cell(self) -> ipywidgets.Output: """Prepare the cell to display the test results""" output_cell = ipywidgets.Output() + + # Header with test function name + function = self.ipytest_result.function + title = "Test Results for " if function else "Test Results " output_cell.append_display_data( HTML( - f'

<div class="test-results-header"><h3>Test Results for solution_{self.ipytest_result.function_name}</h3>
' + "<div class='test-results-header'>
" + f'<h2>{title}' + '<span class="function-name">' + f"solution_{function.name}</span></h2>
" + if function is not None + else f'<h2>{title}</h2>' "</div>
" ) ) @@ -195,120 +622,80 @@ def prepare_output_cell(self) -> ipywidgets.Output: ): # We know that there is exactly one exception assert self.ipytest_result.exceptions is not None + # We know that there is no test results + assert self.ipytest_result.test_results is None + exception = self.ipytest_result.exceptions[0] - exceptions_str = ( - format_error(exception) if self.ipytest_result.exceptions else "" + + # Create a TestCaseResult for consistency + error_result = TestCaseResult( + test_name=f"error::solution_{function.name}" if function else "::", + outcome=TestOutcome.TEST_ERROR, + exception=exception, ) - output_cell.append_display_data( - ipywidgets.VBox( - children=[ - HTML(f"

<h4>{type(exception).__name__}</h4>
"), - HTML(exceptions_str), - ] + + output_cell.append_display_data(HTML(error_result.to_html())) + + if self.openai_client: + ai_explains = AIExplanation( + ipytest_result=self.ipytest_result, + exception=exception, + openai_client=self.openai_client, ) - ) - case IPytestOutcome.SOLUTION_FUNCTION_MISSING: - output_cell.append_display_data( - HTML("

<h3>Solution Function Missing</h3>
") - ) + output_cell.append_display_data(ai_explains.render()) case IPytestOutcome.FINISHED if self.ipytest_result.test_results: - captures: Dict[str, Dict[str, str]] = {} - - for test in self.ipytest_result.test_results: - captures[test.test_name.split("::")[-1]] = { - "stdout": test.stdout, - "stderr": test.stderr, - } - - # Create lists of HTML outs and errs - outs = [ - f"

<b>{test_name}</b>
<pre>
{captures[test_name]['stdout']}" - for test_name in captures - if captures[test_name]["stdout"] - ] - errs = [ - f"

<b>{test_name}</b>
<pre>
{captures[test_name]['stderr']}" - for test_name in captures - if captures[test_name]["stderr"] - ] + # Calculate test statistics + total_tests = len(self.ipytest_result.test_results) + passed_tests = sum( + 1 + for test in self.ipytest_result.test_results + if test.outcome == TestOutcome.PASS + ) + failed_tests = total_tests - passed_tests + # Display summary output_cell.append_display_data( - ipywidgets.VBox( - children=( - ipywidgets.Accordion( - children=( - ipywidgets.VBox( - children=[ - HTML(o, style={"background": "#FAFAFA"}) - for o in outs - ] - ), - ), - titles=("Captured output",), - ), - ipywidgets.Accordion( - children=( - ipywidgets.VBox( - children=[ - HTML(e, style={"background": "#FAFAFA"}) - for e in errs - ] - ), - ), - titles=("Captured error",), - ), - ) + HTML( + '
<div class="test-summary">' + f'<div class="summary-pass">
' + f"✅ {passed_tests}/{total_tests} tests passed</div>
" + f'<div class="summary-fail">
' + f"❌ {failed_tests}/{total_tests} tests failed</div>
" + "</div>
" ) ) - success = all( - test.outcome == TestOutcome.PASS + # Display individual test results + for test in self.ipytest_result.test_results: + output_cell.append_display_data(HTML(test.to_html())) + + failed_tests = [ + test for test in self.ipytest_result.test_results - ) + if test.outcome != TestOutcome.PASS + ] - num_results = len(self.ipytest_result.test_results) + if self.openai_client and failed_tests: + ai_explains = AIExplanation( + ipytest_result=self.ipytest_result, + exception=failed_tests[0].exception, + openai_client=self.openai_client, + ) + + output_cell.append_display_data(ai_explains.render()) + case IPytestOutcome.SOLUTION_FUNCTION_MISSING: output_cell.append_display_data( HTML( - f"

<p>👉 We ran {num_results} test{'s' if num_results > 1 else ''}. " - f"""{"All tests passed!</p>
" if success else "Below you find the details for each test run:"}""" + '
<div class="missing-solution">' + '<h3>Solution Function Missing</h3>
' + "<p>Please implement the required solution function.</p>
" + "</div>
" ) ) - if not success: - for result in self.ipytest_result.test_results: - test_succeded = result.outcome == TestOutcome.PASS - test_name = result.test_name.split("::")[-1] - - output_box_children: List[ipywidgets.Widget] = [ - HTML( - f'

{"✔" if test_succeded else "❌"} Test {test_name}

', - style={ - "background": ( - "rgba(251, 59, 59, 0.25)" - if not test_succeded - else "rgba(207, 249, 179, 0.60)" - ) - }, - ) - ] - - if not test_succeded: - assert result.exception is not None - - output_box_children.append( - ipywidgets.Accordion( - children=[HTML(format_error(result.exception))], - titles=("Test results",), - ) - ) - - output_cell.append_display_data( - ipywidgets.VBox(children=output_box_children) - ) - case IPytestOutcome.NO_TEST_FOUND: output_cell.append_display_data(HTML("

<h3>No Test Found</h3>
")) @@ -341,25 +728,26 @@ def __init__(self) -> None: def pytest_runtest_makereport(self, item: pytest.Item, call: pytest.CallInfo): """Called when an individual test item has finished execution.""" if call.when == "call": - if call.excinfo is None: - # Test passes - self.tests[item.nodeid] = TestCaseResult( - test_name=item.nodeid, - outcome=TestOutcome.PASS, - stdout=call.result, - stderr=call.result, - exception=None, - traceback=None, - ) - else: - # Test fails - self.tests[item.nodeid] = TestCaseResult( - test_name=item.nodeid, - outcome=TestOutcome.FAIL, - exception=call.excinfo.value, - traceback=call.excinfo.tb, + test_result = TestCaseResult( + test_name=item.nodeid, + outcome=TestOutcome.FAIL if call.excinfo else TestOutcome.PASS, + exception=call.excinfo.value if call.excinfo else None, + traceback=call.excinfo.tb if call.excinfo else None, + ) + + if call.excinfo: + test_result.formatted_exception = str( + call.excinfo.getrepr( + showlocals=True, + style="long", + funcargs=True, + abspath=False, + chain=True, + ) ) + self.tests[item.nodeid] = test_result + def pytest_exception_interact( self, call: pytest.CallInfo, report: pytest.TestReport ): @@ -385,3 +773,6 @@ def pytest_runtest_logreport(self, report: pytest.TestReport): if test_result := self.tests.get(report.nodeid): test_result.stdout = report.capstdout test_result.stderr = report.capstderr + + if report.failed: + test_result.report_output = str(report.longrepr) diff --git a/tutorial/tests/testsuite/testsuite.py b/tutorial/tests/testsuite/testsuite.py index 30c12b57..ad289383 100644 --- a/tutorial/tests/testsuite/testsuite.py +++ b/tutorial/tests/testsuite/testsuite.py @@ -1,28 +1,36 @@ """A module to define the `%%ipytest` cell magic""" +import ast import dataclasses import inspect import io +import os import pathlib -import re from collections import defaultdict from contextlib import redirect_stderr, redirect_stdout from queue import Queue from threading import Thread -from typing import Callable, Dict, List, Optional +from typing import Dict, List, Optional import ipynbname import pytest +from dotenv import find_dotenv, load_dotenv from IPython.core.interactiveshell import InteractiveShell from IPython.core.magic import Magics, cell_magic, magics_class +from IPython.display import HTML, display +from .ai_helpers import OpenAIWrapper from .ast_parser import AstParser from .exceptions import ( FunctionNotFoundError, InstanceNotFoundError, + OpenAIWrapperError, + PytestInternalError, TestModuleNotFoundError, ) from .helpers import ( + AFunction, + DebugOutput, FunctionInjectionPlugin, IPytestOutcome, IPytestResult, @@ -32,11 +40,11 @@ ) -def run_test( - module_file: pathlib.Path, function_name: str, function_object: Callable +def run_pytest_for_function( + module_file: pathlib.Path, function: AFunction ) -> IPytestResult: """ - Run the tests for a single function + Runs pytest for a single function and returns an `IPytestResult` object """ with redirect_stdout(io.StringIO()) as _, redirect_stderr(io.StringIO()) as _: # Create the test collector @@ -44,9 +52,9 @@ def run_test( # Run the tests result = pytest.main( - ["-k", f"test_{function_name}", f"{module_file}"], + ["-k", f"test_{function.name}", f"{module_file}"], plugins=[ - FunctionInjectionPlugin(function_object), + FunctionInjectionPlugin(function.implementation), result_collector, ], ) @@ -54,7 +62,7 @@ def run_test( match result: case pytest.ExitCode.OK: return IPytestResult( - function_name=function_name, + function=function, status=IPytestOutcome.FINISHED, 
test_results=list(result_collector.tests.values()), ) @@ -64,7 +72,7 @@ def run_test( for test in result_collector.tests.values() ): return IPytestResult( - function_name=function_name, + function=function, status=IPytestOutcome.PYTEST_ERROR, exceptions=[ test.exception @@ -74,19 +82,19 @@ def run_test( ) return IPytestResult( - function_name=function_name, + function=function, status=IPytestOutcome.FINISHED, test_results=list(result_collector.tests.values()), ) case pytest.ExitCode.INTERNAL_ERROR: return IPytestResult( - function_name=function_name, + function=function, status=IPytestOutcome.PYTEST_ERROR, - exceptions=[Exception("Internal error")], + exceptions=[PytestInternalError()], ) case pytest.ExitCode.NO_TESTS_COLLECTED: return IPytestResult( - function_name=function_name, + function=function, status=IPytestOutcome.NO_TEST_FOUND, exceptions=[FunctionNotFoundError()], ) @@ -96,14 +104,13 @@ def run_test( ) -def run_test_in_thread( +def run_pytest_in_background( module_file: pathlib.Path, - function_name: str, - function_object: Callable, + function: AFunction, test_queue: Queue, ): - """Run the tests for a single function and put the result in the queue""" - test_queue.put(run_test(module_file, function_name, function_object)) + """Runs pytest in a background thread and puts the result in the provided queue""" + test_queue.put(run_pytest_for_function(module_file, function)) def _name_from_line(line: str = ""): @@ -142,9 +149,9 @@ class TestMagic(Magics): def __init__(self, shell): super().__init__(shell) - self.max_execution_count = 3 self.shell: InteractiveShell = shell self.cell: str = "" + self.debug: bool = False self.module_file: Optional[pathlib.Path] = None self.module_name: Optional[str] = None self.threaded: Optional[bool] = None @@ -155,38 +162,43 @@ def __init__(self, shell): self._orig_traceback = self.shell._showtraceback # type: ignore # This is monkey-patching suppress printing any exception or traceback - def extract_functions_to_test(self) -> Dict[str, Callable]: - """""" - # Retrieve the functions names defined in the current cell + def extract_functions_to_test(self) -> List[AFunction]: + """Retrieve the functions names and implementations defined in the current cell""" # Only functions with names starting with `solution_` will be candidates for tests - functions_names: List[str] = re.findall( - r"^(?:async\s+?)?def\s+(solution_.*?)\s*\(", self.cell, re.M - ) - - return { - name.removeprefix("solution_"): function + functions: Dict[str, str] = {} + tree = ast.parse(self.cell) + + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name.startswith("solution_"): + functions.update({str(node.name): ast.unparse(node)}) + + return [ + AFunction( + name=name.removeprefix("solution_"), + implementation=function, + source_code=functions[name], + ) for name, function in self.shell.user_ns.items() - if name in functions_names + if name in functions and (callable(function) or inspect.iscoroutinefunction(function)) - } + ] - def run_test(self, function_name: str, function_object: Callable) -> IPytestResult: - """Run the tests for a single function""" + def run_test_with_tracking(self, function: AFunction) -> IPytestResult: + """Runs tests for a function while tracking execution count and handling threading""" assert isinstance(self.module_file, pathlib.Path) # Store execution count information for each cell cell_id = str(self.shell.parent_header["metadata"]["cellId"]) # type: ignore - self.cell_execution_count[cell_id][function_name] += 1 + 
self.cell_execution_count[cell_id][function.name] += 1 # Run the tests on a separate thread if self.threaded: assert isinstance(self.test_queue, Queue) thread = Thread( - target=run_test_in_thread, + target=run_pytest_in_background, args=( self.module_file, - function_name, - function_object, + function, self.test_queue, ), ) @@ -194,19 +206,19 @@ def run_test(self, function_name: str, function_object: Callable) -> IPytestResu thread.join() result = self.test_queue.get() else: - result = run_test(self.module_file, function_name, function_object) + result = run_pytest_for_function(self.module_file, function) match result.status: case IPytestOutcome.FINISHED: return dataclasses.replace( result, - test_attempts=self.cell_execution_count[cell_id][function_name], + test_attempts=self.cell_execution_count[cell_id][function.name], ) case _: return result def run_cell(self) -> List[IPytestResult]: - # Run the cell through IPython + """Evaluates the cell via IPython and runs tests for the functions""" try: result = self.shell.run_cell(self.cell, silent=True) # type: ignore result.raise_error() @@ -215,6 +227,7 @@ def run_cell(self) -> List[IPytestResult]: IPytestResult( status=IPytestOutcome.COMPILE_ERROR, exceptions=[err], + cell_content=self.cell, ) ] @@ -230,7 +243,7 @@ def run_cell(self) -> List[IPytestResult]: # Run the tests for each function test_results = [ - self.run_test(name, function) for name, function in functions_to_run.items() + self.run_test_with_tracking(function) for function in functions_to_run ] return test_results @@ -246,6 +259,11 @@ def ipytest(self, line: str, cell: str): self.cell = cell line_contents = set(line.split()) + # Debug mode? + if "debug" in line_contents: + line_contents.remove("debug") + self.debug = True + # Check if we need to run the tests on a separate thread if "async" in line_contents: line_contents.remove("async") @@ -253,8 +271,7 @@ def ipytest(self, line: str, cell: str): self.test_queue = Queue() # If debug is in the line, then we want to show the traceback - if "debug" in line_contents: - line_contents.remove("debug") + if self.debug: self.shell._showtraceback = self._orig_traceback else: self.shell._showtraceback = lambda *args, **kwargs: None @@ -280,16 +297,29 @@ def ipytest(self, line: str, cell: str): # Run the cell results = self.run_cell() + # If in debug mode, display debug information first + if self.debug: + debug_output = DebugOutput( + module_name=self.module_name, + module_file=self.module_file, + results=results, + ) + display(HTML(debug_output.to_html())) + # Parse the AST of the test module to retrieve the solution code ast_parser = AstParser(self.module_file) # Display the test results and the solution code for result in results: solution = ( - ast_parser.get_solution_code(result.function_name) - if result.function_name + ast_parser.get_solution_code(result.function.name) + if result.function and result.function.name else None ) - TestResultOutput(result, solution).display_results() + TestResultOutput( + result, + solution, + self.shell.openai_client, # type: ignore + ).display_results() def load_ipython_extension(ipython): @@ -298,5 +328,62 @@ def load_ipython_extension(ipython): can be loaded via `%load_ext module.path` or be configured to be autoloaded by IPython at startup time. 
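It also loads the optional openai.env configuration (see openai.env.example) and, when the key and model validate, attaches a ready-to-use OpenAI client to the IPython instance so that failed tests can be explained by the AI helper.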
""" + # Configure the API key for the OpenAI client + openai_env = find_dotenv("openai.env") + if openai_env: + load_dotenv(openai_env) + + api_key = os.getenv("OPENAI_API_KEY") + model = os.getenv("OPENAI_MODEL") + language = os.getenv("OPENAI_LANGUAGE") + + # First, validate the key + key_validation = OpenAIWrapper.validate_api_key(api_key) + if not key_validation.is_valid: + message = key_validation.user_message + message_color = "#ffebee" # Red + ipython.openai_client = None + else: + assert api_key is not None # must be so at this point + try: + openai_client, model_validation = OpenAIWrapper.create_validated( + api_key, model, language + ) + if model_validation.is_valid: + ipython.openai_client = openai_client + message_color = "#d9ead3" # Green + else: + message_color = "#ffebee" # Red + ipython.openai_client = None + + message = model_validation.user_message + except OpenAIWrapperError as e: + ipython.openai_client = None + message = f"🚫 OpenAI configuration error:
{str(e)}" + message_color = "#ffebee" + except Exception as e: + # Handle any other unexpected errors + ipython.openai_client = None + message = ( + f"🚫 Unexpected error:
{str(e)}" + ) + message_color = "#ffebee" + + display( + HTML( + "
" + f"{message}" + "
" + ) + ) + + # Register the magic ipython.register_magics(TestMagic) + + message = ( + "
" + "🔄 IPytest extension (re)loaded.
" + ) + display(HTML(message))