Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sandevh/stu 307 generate tests with an llm #1122

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# python
/venv
/venv2
/gen_tests

# dependencies
/node_modules
Expand Down
16 changes: 7 additions & 9 deletions blocks/HARDWARE/SOURCEMETERS/KEITHLEY/24XX/IV_SWEEP/IV_SWEEP.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
import serial
import numpy as np
from flojoy import SerialConnection, flojoy, OrderedPair, Vector
from typing import cast

import numpy as np
import serial
from flojoy import OrderedPair, SerialConnection, Vector, flojoy


@flojoy(deps={"pyserial": "3.5"}, inject_connection=True)
def IV_SWEEP(
connection: SerialConnection, default: OrderedPair | Vector
) -> OrderedPair:
"""Take an I-V curve measurement with a Keithley 2400 source meter (send voltages, measure currents).

Inputs
------
default: OrderedPair | Vector
The voltages to send to the Keithley 2400 source meter.

Parameters
----------
default: OrderedPair | Vector
The voltages to send to the Keithley 2400 source meter.
connection: Serial
The open connection with the Keithley2400 source meter.
The open connection with the Keithley 2400 source meter.

Returns
-------
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
{
"docstring": {
"long_description": "Inputs\n------\ndefault: OrderedPair | Vector\n The voltages to send to the Keithley 2400 source meter.",
"long_description": "",
"short_description": "Take an I-V curve measurement with a Keithley 2400 source meter (send voltages, measure currents).",
"parameters": [
{
"name": "default",
"type": "OrderedPair | Vector",
"description": "The voltages to send to the Keithley 2400 source meter."
},
{
"name": "connection",
"type": "Serial",
"description": "The open connection with the Keithley2400 source meter."
"description": "The open connection with the Keithley 2400 source meter."
}
],
"returns": [
Expand Down
18 changes: 14 additions & 4 deletions captain/models/test_sequencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ class LockedContextType(BaseModel):


class TestTypes(str, Enum):
Pytest = "Pytest"
Python = "Python"
Flojoy = "Flojoy"
Matlab = "Matlab"
pytest = "Pytest"
python = "Python"
flojoy = "Flojoy"
matlab = "Matlab"


class StatusTypes(str, Enum):
Expand Down Expand Up @@ -129,3 +129,13 @@ class TestSequenceRun(BaseModel):
data: Union[str, TestRootNode]
hardware_id: Union[str, None]
project_id: Union[str, None]


class GenerateTestRequest(BaseModel):
test_name: str = Field(..., alias="testName")
test_type: TestTypes = Field(..., alias="testType")
prompt: str = Field(..., alias="prompt")


class TestGenerationContainer(BaseModel):
test: Test
19 changes: 10 additions & 9 deletions captain/parser/bool_parser/bool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@
track_identifiers = set()


def _match_literal_or_var(s: str, ptr: int, allowed_symbols: set[str]):
izi-on marked this conversation as resolved.
Show resolved Hide resolved
start = ptr
while ptr < len(s) and s[ptr] in allowed_symbols:
ptr += 1
if ptr < len(s) and s[ptr] not in language:
raise InvalidCharacter(s[ptr])
return start, ptr


def _tokenize(s: str, symbol_table: SymbolTableType) -> list[Token]:
"""
Tokenizes the string input from the user
Expand All @@ -64,14 +73,6 @@ def _match_op_symbol(ptr: int) -> tuple[Operator, int]:
c = s[ptr : ptr + extend]
return Operator(c), extend

def _match_literal_or_var(ptr: int, allowed_symbols: set[str]):
start = ptr
while ptr < len(s) and s[ptr] in allowed_symbols:
ptr += 1
if ptr < len(s) and s[ptr] not in language:
raise InvalidCharacter(s[ptr])
return start, ptr

tokens: List[Token] = []
s = s.replace(" ", "") # remove whitespace
ptr = 0
Expand All @@ -97,7 +98,7 @@ def _match_literal_or_var(ptr: int, allowed_symbols: set[str]):
*boolean_literal_symbols,
*variable_symbols,
}
start_ptr, end_ptr = _match_literal_or_var(ptr, allowed_symbols)
start_ptr, end_ptr = _match_literal_or_var(s, ptr, allowed_symbols)
token_str = s[start_ptr:end_ptr]
if BooleanLiteral.allows(token_str):
tokens.append(BooleanLiteral(token_str))
Expand Down
10 changes: 10 additions & 0 deletions captain/parser/bool_parser/utils/name_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from captain.parser.bool_parser.bool_parser import variable_symbols
from captain.parser.bool_parser.expressions.exceptions import InvalidCharacter


def validate_name(name: str):
for c in name:
if c not in variable_symbols:
raise InvalidCharacter(
f"{c}, only {[s for s in variable_symbols]} is allowed"
)
22 changes: 21 additions & 1 deletion captain/routes/test_sequence.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import asyncio
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
import json
import pydantic
from captain.models.pytest.pytest_models import TestDiscoverContainer
from captain.models.test_sequencer import TestSequenceRun
from captain.models.test_sequencer import (
GenerateTestRequest,
TestGenerationContainer,
TestSequenceRun,
)
from captain.utils.test_sequencer.generator.generate_test import generate_test
from captain.utils.pytest.discover_tests import discover_pytest_file
from captain.utils.config import ts_manager
from captain.utils.test_sequencer.handle_data import handle_data
Expand Down Expand Up @@ -53,3 +59,17 @@ async def discover_pytest(params: DiscoverPytestParams = Depends()):
return TestDiscoverContainer(
response=return_val, missingLibraries=missing_lib
).model_dump_json(by_alias=True)


@router.post("/generate-test/")
async def generate_test_endpoint(requested_test: GenerateTestRequest):
test_container = {}
test_name = requested_test.test_name
test_type = requested_test.test_type
prompt = requested_test.prompt
await generate_test(test_name, test_type, prompt, test_container)
if "test" not in test_container:
return
return TestGenerationContainer(test=test_container["test"]).model_dump_json(
by_alias=True
)
35 changes: 35 additions & 0 deletions captain/utils/test_sequencer/data_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import asyncio
import traceback
from captain.utils.logger import logger
from captain.types.test_sequence import MsgState, TestSequenceMessage
from captain.utils.config import ts_manager


async def _stream_result_to_frontend(
state: MsgState,
test_id: str = "",
result: bool = False,
time_taken: float = 0,
is_saved_to_cloud: bool = False,
error: str | None = None,
):
asyncio.create_task(
ts_manager.ws.broadcast(
TestSequenceMessage(
state.value, test_id, result, time_taken, is_saved_to_cloud, error
)
)
)
await asyncio.sleep(0) # necessary for task yield
await asyncio.sleep(0) # still necessary for task yield


def _with_error_report(func):
async def reported_func(*args, **kwargs):
try:
await func(*args, **kwargs)
except Exception as e:
await _stream_result_to_frontend(state=MsgState.ERROR, error=str(e))
logger.error(f"{e}: {traceback.format_exc()}")

return reported_func
Empty file.
55 changes: 55 additions & 0 deletions captain/utils/test_sequencer/generator/generate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import uuid
from captain.models.test_sequencer import StatusTypes, Test, TestTypes
from captain.parser.bool_parser.utils.name_validator import validate_name
from captain.utils.logger import logger
import os
from openai import OpenAI
from captain.utils.test_sequencer.data_stream import _with_error_report

key = "" # TODO have this from .env, Joey said to hard code for now until his PR
client = OpenAI(api_key=key)


@_with_error_report
async def generate_test(
test_name: str, test_type: TestTypes, prompt: str, test_container: dict
):
if test_type.value != "Python": # for now only handle python tests
raise Exception("Only Python tests allowed")
validate_name(test_name) # this raises an error if invalid
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{
"role": "system",
"content": "You generate only raw python code for tests using assertions. If the user's request does not imply a test, simply output 'NULL'",
LatentDream marked this conversation as resolved.
Show resolved Hide resolved
},
{
"role": "user",
"content": prompt,
},
],
)
code_resp = response.choices[0].message.content
if code_resp is None:
raise Exception("Unable to generate code")
if code_resp == "NULL":
raise Exception("Invalid prompt/request, please specify something to test")
LatentDream marked this conversation as resolved.
Show resolved Hide resolved
code = "\n".join(code_resp.splitlines()[1:-1])
path_to_gen_folder = os.path.join(os.getcwd(), "gen_tests")
path_to_file = os.path.join(path_to_gen_folder, test_name)
if not os.path.exists(path_to_gen_folder):
os.makedirs(path_to_gen_folder)
with open(path_to_file, "w") as file:
file.write(code)
test_container["test"] = Test.construct(
type="test",
id=uuid.uuid4(),
group_id=uuid.uuid4(),
path=path_to_file,
test_name=test_name,
run_in_parallel=False,
test_type=TestTypes.python,
status=StatusTypes.pending,
is_saved_to_cloud=False,
)
Loading