Skip to content

Commit

Permalink
Add unit tests for modules.
Browse files Browse the repository at this point in the history
  • Loading branch information
tienhiep11 committed Jan 1, 2024
1 parent 0249e11 commit aa22e3f
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 0 deletions.
42 changes: 42 additions & 0 deletions tests/api/test_routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest
from fastapi import WebSocketDisconnect
from src.api.routers import websocket_endpoint, manager
from unittest.mock import AsyncMock

from unittest.mock import patch

### Remind to read again Copilot chat for the test_routers.py file

@pytest.mark.asyncio
@patch('src.api.routers.get_openai_chain')
@patch('src.api.routers.StreamingLLMCallbackHandler')
@patch('src.api.routers.manager.connect')
async def test_websocket_endpoint_receives_and_sends_text(mock_connect, mock_handler, mock_chain):
mock_websocket = AsyncMock()
mock_websocket.receive_text.return_value = "Hello, world!"
await websocket_endpoint(mock_websocket)
mock_websocket.send_json.assert_called()
mock_connect.assert_called_once_with(mock_websocket)

@pytest.mark.asyncio
@patch('src.api.routers.get_openai_chain')
@patch('src.api.routers.StreamingLLMCallbackHandler')
@patch('src.api.routers.manager.connect')
@patch('src.api.routers.manager.disconnect')
async def test_websocket_endpoint_handles_disconnect(mock_disconnect, mock_connect, mock_handler, mock_chain):
mock_websocket = AsyncMock()
mock_websocket.receive_text.side_effect = WebSocketDisconnect()
await websocket_endpoint(mock_websocket)
mock_connect.assert_called_once_with(mock_websocket)
mock_disconnect.assert_called_once_with(mock_websocket)

@pytest.mark.asyncio
@patch('src.api.routers.get_openai_chain')
@patch('src.api.routers.StreamingLLMCallbackHandler')
@patch('src.api.routers.manager.connect')
async def test_websocket_endpoint_handles_exception(mock_connect, mock_handler, mock_chain):
mock_websocket = AsyncMock()
mock_websocket.receive_text.side_effect = Exception()
await websocket_endpoint(mock_websocket)
mock_websocket.send_json.assert_called()
mock_connect.assert_called_once_with(mock_websocket)
12 changes: 12 additions & 0 deletions tests/integrations/test_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import pytest
from unittest.mock import patch, MagicMock
from src.integrations.openai import Settings

def test_settings():
settings = Settings()
assert settings.PROJECT_NAME == "llm-playground"
assert settings.API_VERSION == "v1"
assert settings.API_V1_STR == "/api/v1"
assert settings.MODEL_NAME == "gpt-3.5-turbo"
assert settings.TEMPERATURE == 0.7
assert settings.MAX_TOKENS == 2000
27 changes: 27 additions & 0 deletions tests/schemas/test_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import pytest
from pydantic import ValidationError
from src.schemas.message import ChatResponse

def test_chat_response_valid():
# Test with valid data
response = ChatResponse(sender="bot", message="Hello, world!", type="start")
assert response.sender == "bot"
assert response.message == "Hello, world!"
assert response.type == "start"

def test_chat_response_invalid_sender():
# Test with invalid sender
with pytest.raises(ValidationError):
ChatResponse(sender="invalid", message="Hello, world!", type="start")

def test_chat_response_invalid_type():
# Test with invalid type
with pytest.raises(ValidationError):
ChatResponse(sender="bot", message="Hello, world!", type="invalid")

def test_chat_response_empty_message():
# Test with empty message
response = ChatResponse(sender="bot", message="", type="start")
assert response.sender == "bot"
assert response.message == ""
assert response.type == "start"
23 changes: 23 additions & 0 deletions tests/utils/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# FILEPATH: /Users/hiep/Desktop/Workspace/aitomatic_test/llm-playground/tests/utils/test_callbacks.py

import pytest
from unittest.mock import AsyncMock
from src.utils.callbacks import StreamingLLMCallbackHandler
from src.schemas.message import ChatResponse

@pytest.mark.asyncio
async def test_on_llm_new_token():
# Mock the websocket
mock_websocket = AsyncMock()

# Create an instance of the handler
handler = StreamingLLMCallbackHandler(mock_websocket)

# Call the on_llm_new_token method
await handler.on_llm_new_token('test_token')

# Create the expected response
expected_resp = ChatResponse(sender="bot", message='test_token', type="stream")

# Check that the websocket's send_json method was called with the correct argument
mock_websocket.send_json.assert_awaited_once_with(expected_resp.dict())
50 changes: 50 additions & 0 deletions tests/utils/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import sys
from unittest.mock import patch, MagicMock, call
from src.utils.logger import setup_logging, APP_LOGGER_NAME

@patch('logging.getLogger')
@patch('logging.StreamHandler')
def test_setup_logging_default(mock_stream_handler, mock_get_logger):
# Mock the logger and stream handler
mock_logger = MagicMock()
mock_get_logger.return_value = mock_logger
mock_handler = MagicMock()
mock_stream_handler.return_value = mock_handler

# Call the function
result = setup_logging()

# Check that the mocks were called with the correct arguments
mock_get_logger.assert_called_once_with(APP_LOGGER_NAME)
mock_stream_handler.assert_called_once_with(sys.stdout)
mock_logger.addHandler.assert_called_once_with(mock_handler)

# Check that the result is the mocked logger
assert result == mock_logger

@patch('logging.getLogger')
@patch('logging.StreamHandler')
@patch('logging.FileHandler')
def test_setup_logging_with_file(mock_file_handler, mock_stream_handler, mock_get_logger):
# Mock the logger and handlers
mock_logger = MagicMock()
mock_get_logger.return_value = mock_logger
mock_stream_handler_instance = MagicMock()
mock_stream_handler.return_value = mock_stream_handler_instance
mock_file_handler_instance = MagicMock()
mock_file_handler.return_value = mock_file_handler_instance

# Call the function
result = setup_logging(file_name='test.log')

# Check that the mocks were called with the correct arguments
mock_get_logger.assert_called_once_with(APP_LOGGER_NAME)
mock_stream_handler.assert_called_once_with(sys.stdout)
mock_file_handler.assert_called_once_with('test.log')

# Check that the logger's addHandler method was called with the correct arguments
calls = [call(mock_stream_handler_instance), call(mock_file_handler_instance)]
mock_logger.addHandler.assert_has_calls(calls)

# Check that the result is the mocked logger
assert result == mock_logger

0 comments on commit aa22e3f

Please sign in to comment.