Merge pull request #4 from flare-research/openrouter-typing

fix(typing): strictly type openrouter, simplify imports and routing logic

magurh authored Feb 17, 2025
2 parents 08be5ca + 36475da commit da9edcd
Showing 23 changed files with 380 additions and 351 deletions.
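
The typing work in the diffs below leans on two typed structures, Message and ChatRequest, imported from the refactored flare_ai_consensus.router and flare_ai_consensus.settings packages. Their definitions are not part of this commit view; judging from how payloads are built in the changed files, they would look roughly like the sketch below. Treat it as an assumption inferred from usage, not the repository's actual code.

# Hypothetical sketch of the typed structures used in this commit -- field names
# are inferred from how payloads are constructed in the diffs, not copied from the repo.
from typing import TypedDict


class Message(TypedDict):
    role: str
    content: str


class ChatRequest(TypedDict):
    model: str
    messages: list[Message]
    max_tokens: int
    temperature: float

The real definitions may differ in detail (optional fields, extra keys); the point of the change is that payload construction is now checked against a declared shape instead of a bare dict.
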
42 changes: 0 additions & 42 deletions src/flare_ai_consensus/config.py

This file was deleted.

8 changes: 8 additions & 0 deletions src/flare_ai_consensus/consensus/__init__.py
@@ -0,0 +1,8 @@
+from .aggregator import async_centralized_llm_aggregator, centralized_llm_aggregator
+from .consensus import send_round
+
+__all__ = [
+    "async_centralized_llm_aggregator",
+    "centralized_llm_aggregator",
+    "send_round",
+]
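
The new package __init__ re-exports the consensus entry points, so downstream code can import them from the package itself rather than reaching into individual modules, for example:

from flare_ai_consensus.consensus import centralized_llm_aggregator, send_round
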
32 changes: 18 additions & 14 deletions src/flare_ai_consensus/consensus/aggregator.py
@@ -1,8 +1,12 @@
-from flare_ai_consensus.consensus.config import AggregatorConfig
-from flare_ai_consensus.router.client import AsyncOpenRouterClient, OpenRouterClient
+from flare_ai_consensus.router import (
+    AsyncOpenRouterProvider,
+    ChatRequest,
+    OpenRouterProvider,
+)
+from flare_ai_consensus.settings import AggregatorConfig, Message


-def concatenate_aggregator(responses: dict[str, str]) -> str:
+def _concatenate_aggregator(responses: dict[str, str]) -> str:
     """
     Aggregate responses by concatenating each model's answer with a label.
@@ -13,52 +17,52 @@ def concatenate_aggregator(responses: dict[str, str]) -> str:


 def centralized_llm_aggregator(
-    client: OpenRouterClient,
+    provider: OpenRouterProvider,
     aggregator_config: AggregatorConfig,
     aggregated_responses: dict[str, str],
 ) -> str:
     """Use a centralized LLM to combine responses.

-    :param client: An OpenRouterClient instance.
+    :param provider: An OpenRouterProvider instance.
     :param aggregator_config: An instance of AggregatorConfig.
     :param aggregated_responses: A string containing aggregated
         responses from individual models.
     :return: The aggregator's combined response.
     """
     # Build the message list.
-    messages = []
+    messages: list[Message] = []
     messages.extend(aggregator_config.context)

     # Add a system message with the aggregated responses.
-    aggregated_str = concatenate_aggregator(aggregated_responses)
+    aggregated_str = _concatenate_aggregator(aggregated_responses)
     messages.append(
         {"role": "system", "content": f"Aggregated responses:\n{aggregated_str}"}
     )

     # Add the aggregator prompt
     messages.extend(aggregator_config.prompt)

-    payload = {
+    payload: ChatRequest = {
         "model": aggregator_config.model.model_id,
         "messages": messages,
         "max_tokens": aggregator_config.model.max_tokens,
         "temperature": aggregator_config.model.temperature,
     }

     # Get aggregated response from the centralized LLM
-    response = client.send_chat_completion(payload)
+    response = provider.send_chat_completion(payload)
     return response.get("choices", [])[0].get("message", {}).get("content", "")


 async def async_centralized_llm_aggregator(
-    client: AsyncOpenRouterClient,
+    provider: AsyncOpenRouterProvider,
     aggregator_config: AggregatorConfig,
     aggregated_responses: dict[str, str],
 ) -> str:
     """
-    Use a centralized LLM (via an async client) to combine responses.
+    Use a centralized LLM (via an async provider) to combine responses.

-    :param client: An asynchronous OpenRouter client.
+    :param provider: An asynchronous OpenRouterProvider.
     :param aggregator_config: An instance of AggregatorConfig.
     :param aggregated_responses: A string containing aggregated
         responses from individual models.
Expand All @@ -71,12 +75,12 @@ async def async_centralized_llm_aggregator(
)
messages.extend(aggregator_config.prompt)

payload = {
payload: ChatRequest = {
"model": aggregator_config.model.model_id,
"messages": messages,
"max_tokens": aggregator_config.model.max_tokens,
"temperature": aggregator_config.model.temperature,
}

response = await client.send_chat_completion(payload)
response = await provider.send_chat_completion(payload)
return response.get("choices", [])[0].get("message", {}).get("content", "")
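
To illustrate how the renamed provider and the config types fit together, here is a minimal usage sketch for the synchronous aggregator. The constructor calls are assumptions: this diff only shows that AggregatorConfig exposes context, prompt and model, and that the model config carries model_id, max_tokens and temperature; the actual constructors live elsewhere in the repository.

# Sketch only: the OpenRouterProvider, AggregatorConfig and ModelConfig constructors
# below are assumed, since their definitions are not shown in this commit.
from flare_ai_consensus.consensus import centralized_llm_aggregator
from flare_ai_consensus.router import OpenRouterProvider
from flare_ai_consensus.settings import AggregatorConfig, ModelConfig

provider = OpenRouterProvider(api_key="<OPENROUTER_API_KEY>")  # assumed constructor

aggregator_config = AggregatorConfig(  # assumed constructor shape
    model=ModelConfig(model_id="openai/gpt-4o", max_tokens=512, temperature=0.2),
    context=[{"role": "system", "content": "You merge answers from several models."}],
    prompt=[{"role": "user", "content": "Combine the aggregated responses into one answer."}],
)

# Responses collected from the individual models in a consensus round.
model_responses = {
    "model-a": "First candidate answer.",
    "model-b": "Second candidate answer.",
}

combined = centralized_llm_aggregator(provider, aggregator_config, model_responses)
print(combined)
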
65 changes: 0 additions & 65 deletions src/flare_ai_consensus/consensus/config.py

This file was deleted.

28 changes: 14 additions & 14 deletions src/flare_ai_consensus/consensus/consensus.py
@@ -2,16 +2,16 @@

 import structlog

-from flare_ai_consensus.consensus.config import ConsensusConfig, ModelConfig
-from flare_ai_consensus.router.client import AsyncOpenRouterClient
-from flare_ai_consensus.utils.parser import parse_chat_response
+from flare_ai_consensus.router import AsyncOpenRouterProvider, ChatRequest
+from flare_ai_consensus.settings import ConsensusConfig, Message, ModelConfig
+from flare_ai_consensus.utils import parse_chat_response

 logger = structlog.get_logger(__name__)


-def build_improvement_conversation(
+def _build_improvement_conversation(
     consensus_config: ConsensusConfig, aggregated_response: str
-) -> list:
+) -> list[Message]:
     """Build an updated conversation using the consensus configuration.

     :param consensus_config: An instance of ConsensusConfig.
@@ -35,16 +35,16 @@ def build_improvement_conversation(
     return conversation


-async def get_response_for_model(
-    client: AsyncOpenRouterClient,
+async def _get_response_for_model(
+    provider: AsyncOpenRouterProvider,
     consensus_config: ConsensusConfig,
     model: ModelConfig,
     aggregated_response: str | None,
 ) -> tuple[str | None, str]:
     """
     Asynchronously sends a chat completion request for a given model.

-    :param client: An instance of an asynchronous OpenRouter client.
+    :param provider: An instance of an asynchronous OpenRouter provider.
     :param consensus_config: An instance of ConsensusConfig.
     :param aggregated_response: The aggregated consensus response
         from the previous round (or None).
@@ -57,39 +57,39 @@ async def get_response_for_model(
         logger.info("sending initial prompt", model_id=model.model_id)
     else:
         # Build the improvement conversation.
-        conversation = build_improvement_conversation(
+        conversation = _build_improvement_conversation(
             consensus_config, aggregated_response
         )
         logger.info("sending improvement prompt", model_id=model.model_id)

-    payload = {
+    payload: ChatRequest = {
         "model": model.model_id,
         "messages": conversation,
         "max_tokens": model.max_tokens,
         "temperature": model.temperature,
     }
-    response = await client.send_chat_completion(payload)
+    response = await provider.send_chat_completion(payload)
     text = parse_chat_response(response)
     logger.info("new response", model_id=model.model_id, response=text)
     return model.model_id, text


 async def send_round(
-    client: AsyncOpenRouterClient,
+    provider: AsyncOpenRouterProvider,
     consensus_config: ConsensusConfig,
     aggregated_response: str | None = None,
 ) -> dict:
     """
     Asynchronously sends a round of chat completion requests for all models.

-    :param client: An instance of an asynchronous OpenRouter client.
+    :param provider: An instance of an asynchronous OpenRouter provider.
     :param consensus_config: An instance of ConsensusConfig.
     :param aggregated_response: The aggregated consensus response from the
         previous round (or None).
     :return: A dictionary mapping model IDs to their response texts.
     """
     tasks = [
-        get_response_for_model(client, consensus_config, model, aggregated_response)
+        _get_response_for_model(provider, consensus_config, model, aggregated_response)
         for model in consensus_config.models
     ]
     results = await asyncio.gather(*tasks)
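
Taken together, send_round fans the current conversation out to every configured model with asyncio.gather and returns a mapping from model ID to response text, which the aggregator from the previous file then condenses. A minimal driver for one round might look like the sketch below; the provider and config objects are assumed to be constructed elsewhere, as their definitions are not part of this diff.

# Sketch of one consensus round; AsyncOpenRouterProvider, ConsensusConfig and
# AggregatorConfig are assumed to be built elsewhere (not shown in this commit).
import asyncio

from flare_ai_consensus.consensus import async_centralized_llm_aggregator, send_round


async def run_one_round(provider, consensus_config, aggregator_config) -> str:
    # Fan out the prompt to all configured models in parallel.
    responses = await send_round(provider, consensus_config)
    # Condense the per-model answers into a single aggregated response.
    return await async_centralized_llm_aggregator(provider, aggregator_config, responses)

# A follow-up round would feed the aggregated text back in via
# send_round(provider, consensus_config, aggregated_response=...).
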
