
Commit

"feat: completing text /chat-completion and /completion provider and …
Browse files Browse the repository at this point in the history
…e1e tests"
  • Loading branch information
LESSuseLESS committed Feb 25, 2025
1 parent 19ae4b3 commit 056432f
Showing 8 changed files with 480 additions and 224 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -3,4 +3,4 @@ include distributions/dependencies.json
 include llama_stack/distribution/*.sh
 include llama_stack/cli/scripts/*.sh
 include llama_stack/templates/*/*.yaml
-include llama_stack/providers/tests/test_cases/*.json
+include llama_stack/providers/tests/test_cases/inference/*.json
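The packaging glob now reaches one level deeper: test-case fixtures move from a flat test_cases/ directory into per-API subdirectories. Judging from the parametrize IDs in the test diff below (e.g. inference:completion:non_streaming) and the hard-coded prompts they replace, a test_cases/inference/completion.json along these lines would fit; the file name and keys are inferred from this diff, not shown in it:

{
  "non_streaming": {
    "content": "Micheael Jordan is born in ",
    "expected": "1963"
  },
  "streaming": {
    "content": "Roses are red,"
  }
}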
180 changes: 114 additions & 66 deletions llama_stack/providers/tests/inference/test_text_inference.py
@@ -27,8 +27,6 @@
     SamplingParams,
     StopReason,
     ToolCall,
-    ToolDefinition,
-    ToolParamDefinition,
     ToolPromptFormat,
 )
 from llama_stack.providers.tests.test_cases.test_case import TestCase
@@ -58,28 +56,6 @@ def common_params(inference_model):
     }
 
 
-@pytest.fixture
-def sample_messages():
-    return [
-        SystemMessage(content="You are a helpful assistant."),
-        UserMessage(content="What's the weather like today?"),
-    ]
-
-
-@pytest.fixture
-def sample_tool_definition():
-    return ToolDefinition(
-        tool_name="get_weather",
-        description="Get the current weather",
-        parameters={
-            "location": ToolParamDefinition(
-                param_type="string",
-                description="The city and state, e.g. San Francisco, CA",
-            ),
-        },
-    )
-
-
 class TestInference:
     # Session scope for asyncio because the tests in this class all
     # share the same provider instance.
@@ -100,12 +76,20 @@ async def test_model_list(self, inference_model, inference_stack):
 
         assert model_def is not None
 
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:completion:non_streaming",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_completion(self, inference_model, inference_stack):
+    async def test_text_completion_non_streaming(self, inference_model, inference_stack, test_case):
         inference_impl, _ = inference_stack
 
+        tc = TestCase(test_case)
+
         response = await inference_impl.completion(
-            content="Micheael Jordan is born in ",
+            content=tc["content"],
             stream=False,
             model_id=inference_model,
             sampling_params=SamplingParams(
@@ -114,12 +98,24 @@ async def test_completion(self, inference_model, inference_stack):
         )
 
         assert isinstance(response, CompletionResponse)
-        assert "1963" in response.content
+        assert tc["expected"] in response.content
+
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:completion:streaming",
+        ],
+    )
+    @pytest.mark.asyncio(loop_scope="session")
+    async def test_text_completion_streaming(self, inference_model, inference_stack, test_case):
+        inference_impl, _ = inference_stack
+
+        tc = TestCase(test_case)
 
         chunks = [
             r
             async for r in await inference_impl.completion(
-                content="Roses are red,",
+                content=tc["content"],
                 stream=True,
                 model_id=inference_model,
                 sampling_params=SamplingParams(
@@ -133,12 +129,20 @@ async def test_completion(self, inference_model, inference_stack):
         last = chunks[-1]
         assert last.stop_reason == StopReason.out_of_tokens
 
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:completion:logprobs_non_streaming",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_completion_logprobs(self, inference_model, inference_stack):
+    async def test_text_completion_logprobs_non_streaming(self, inference_model, inference_stack, test_case):
         inference_impl, _ = inference_stack
 
+        tc = TestCase(test_case)
+
         response = await inference_impl.completion(
-            content="Micheael Jordan is born in ",
+            content=tc["content"],
             stream=False,
             model_id=inference_model,
             sampling_params=SamplingParams(
@@ -154,10 +158,22 @@ async def test_completion_logprobs(self, inference_model, inference_stack):
         assert response.logprobs, "Logprobs should not be empty"
         assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
 
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:completion:logprobs_streaming",
+        ],
+    )
+    @pytest.mark.asyncio(loop_scope="session")
+    async def test_text_completion_logprobs_streaming(self, inference_model, inference_stack, test_case):
+        inference_impl, _ = inference_stack
+
+        tc = TestCase(test_case)
+
         chunks = [
             r
             async for r in await inference_impl.completion(
-                content="Roses are red,",
+                content=tc["content"],
                 stream=True,
                 model_id=inference_model,
                 sampling_params=SamplingParams(
@@ -180,9 +196,14 @@ async def test_completion_logprobs(self, inference_model, inference_stack):
             else:  # no token, no logprobs
                 assert not chunk.logprobs, "Logprobs should be empty"
 
-    @pytest.mark.parametrize("test_case", ["completion-01"])
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:completion:structured_output",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_completion_structured_output(self, inference_model, inference_stack, test_case):
+    async def test_text_completion_structured_output(self, inference_model, inference_stack, test_case):
         inference_impl, _ = inference_stack
 
         class Output(BaseModel):
@@ -213,14 +234,20 @@ class Output(BaseModel):
         assert answer.year_born == expected["year_born"]
         assert answer.year_retired == expected["year_retired"]
 
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:chat_completion:sample_messages",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_chat_completion_non_streaming(
-        self, inference_model, inference_stack, common_params, sample_messages
-    ):
+    async def test_text_chat_completion_non_streaming(self, inference_model, inference_stack, common_params, test_case):
         inference_impl, _ = inference_stack
+        tc = TestCase(test_case)
+        messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
         response = await inference_impl.chat_completion(
             model_id=inference_model,
-            messages=sample_messages,
+            messages=messages,
             stream=False,
             **common_params,
         )
@@ -230,9 +257,16 @@ async def test_chat_completion_non_streaming(
         assert isinstance(response.completion_message.content, str)
         assert len(response.completion_message.content) > 0
 
-    @pytest.mark.parametrize("test_case", ["chat_completion-01"])
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:chat_completion:structured_output",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_structured_output(self, inference_model, inference_stack, common_params, test_case):
+    async def test_text_chat_completion_structured_output(
+        self, inference_model, inference_stack, common_params, test_case
+    ):
         inference_impl, _ = inference_stack
 
         class AnswerFormat(BaseModel):
@@ -281,14 +315,22 @@ class AnswerFormat(BaseModel):
         with pytest.raises(ValidationError):
             AnswerFormat.model_validate_json(response.completion_message.content)
 
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:chat_completion:sample_messages",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_chat_completion_streaming(self, inference_model, inference_stack, common_params, sample_messages):
+    async def test_text_chat_completion_streaming(self, inference_model, inference_stack, common_params, test_case):
         inference_impl, _ = inference_stack
+        tc = TestCase(test_case)
+        messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
         response = [
             r
             async for r in await inference_impl.chat_completion(
                 model_id=inference_model,
-                messages=sample_messages,
+                messages=messages,
                 stream=True,
                 **common_params,
             )
@@ -304,26 +346,28 @@ async def test_chat_completion_streaming(self, inference_model, inference_stack,
         end = grouped[ChatCompletionResponseEventType.complete][0]
         assert end.event.stop_reason == StopReason.end_of_turn
 
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:chat_completion:sample_messages_tool_calling",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_chat_completion_with_tool_calling(
+    async def test_text_chat_completion_with_tool_calling(
         self,
         inference_model,
         inference_stack,
         common_params,
-        sample_messages,
-        sample_tool_definition,
+        test_case,
     ):
         inference_impl, _ = inference_stack
-        messages = sample_messages + [
-            UserMessage(
-                content="What's the weather like in San Francisco?",
-            )
-        ]
+        tc = TestCase(test_case)
+        messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
 
         response = await inference_impl.chat_completion(
             model_id=inference_model,
             messages=messages,
-            tools=[sample_tool_definition],
+            tools=tc["tools"],
             stream=False,
             **common_params,
         )
@@ -339,32 +383,35 @@ async def test_chat_completion_with_tool_calling(
         assert len(message.tool_calls) > 0
 
         call = message.tool_calls[0]
-        assert call.tool_name == "get_weather"
-        assert "location" in call.arguments
-        assert "San Francisco" in call.arguments["location"]
+        assert call.tool_name == tc["tools"][0]["tool_name"]
+        for name, value in tc["expected"].items():
+            assert name in call.arguments
+            assert value in call.arguments[name]
 
+    @pytest.mark.parametrize(
+        "test_case",
+        [
+            "inference:chat_completion:sample_messages_tool_calling",
+        ],
+    )
     @pytest.mark.asyncio(loop_scope="session")
-    async def test_chat_completion_with_tool_calling_streaming(
+    async def test_text_chat_completion_with_tool_calling_streaming(
         self,
         inference_model,
         inference_stack,
         common_params,
-        sample_messages,
-        sample_tool_definition,
+        test_case,
     ):
         inference_impl, _ = inference_stack
-        messages = sample_messages + [
-            UserMessage(
-                content="What's the weather like in San Francisco?",
-            )
-        ]
+        tc = TestCase(test_case)
+        messages = [TypeAdapter(Message).validate_python(m) for m in tc["messages"]]
 
         response = [
             r
             async for r in await inference_impl.chat_completion(
                 model_id=inference_model,
                 messages=messages,
-                tools=[sample_tool_definition],
+                tools=tc["tools"],
                 stream=True,
                 **common_params,
             )
@@ -397,6 +444,7 @@ async def test_chat_completion_with_tool_calling_streaming(
         assert isinstance(last.event.delta.tool_call, ToolCall)
 
         call = last.event.delta.tool_call
-        assert call.tool_name == "get_weather"
-        assert "location" in call.arguments
-        assert "San Francisco" in call.arguments["location"]
+        assert call.tool_name == tc["tools"][0]["tool_name"]
+        for name, value in tc["expected"].items():
+            assert name in call.arguments
+            assert value in call.arguments[name]
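
The tests above construct TestCase(test_case) and index it like a dict (tc["content"], tc["messages"], tc["tools"], tc["expected"]). The helper itself is not part of the visible diff; a minimal sketch consistent with that usage, assuming IDs decompose as <subdir>:<file>:<case> under test_cases/, might look like:

import json
from pathlib import Path


class TestCase:
    """Sketch only: resolves an ID such as "inference:completion:non_streaming"
    to test_cases/inference/completion.json, key "non_streaming"."""

    def __init__(self, name: str):
        # The three-part mapping is an assumption inferred from the IDs used
        # in the parametrize decorators; the real loader may differ.
        subdir, fname, case = name.split(":")
        path = Path(__file__).parent / "test_cases" / subdir / f"{fname}.json"
        self.data = json.loads(path.read_text())[case]

    def __getitem__(self, key):
        # Dict-style access used by the tests: tc["content"], tc["expected"], ...
        return self.data[key]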
24 changes: 0 additions & 24 deletions llama_stack/providers/tests/test_cases/chat_completion.json

This file was deleted.

13 changes: 0 additions & 13 deletions llama_stack/providers/tests/test_cases/completion.json

This file was deleted.
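
Both flat fixture files give way to per-API JSON under test_cases/inference/, presumably among the changed files not rendered in this excerpt. As a sketch of the replacement format, the tool-calling case consumed above through tc["messages"], tc["tools"], and tc["expected"] could carry the content of the deleted sample_messages and sample_tool_definition fixtures; keys and values here are extrapolated from the old fixtures and the new assertions, not taken from the diff:

{
  "sample_messages_tool_calling": {
    "messages": [
      {"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": "What's the weather like in San Francisco?"}
    ],
    "tools": [
      {
        "tool_name": "get_weather",
        "description": "Get the current weather",
        "parameters": {
          "location": {
            "param_type": "string",
            "description": "The city and state, e.g. San Francisco, CA"
          }
        }
      }
    ],
    "expected": {
      "location": "San Francisco"
    }
  }
}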
