Skip to content

Commit

Permalink
update: add class Config and move env into it
Browse files Browse the repository at this point in the history
  • Loading branch information
1martin1 committed Jan 23, 2025
1 parent 709c027 commit 694532c
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 88 deletions.
10 changes: 5 additions & 5 deletions protollm_tools/llm-worker/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

79 changes: 67 additions & 12 deletions protollm_tools/llm-worker/protollm_worker/config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,71 @@
import os

REDIS_PREFIX = os.environ.get("REDIS_PREFIX", "llm-api")
REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
REDIS_PORT = os.environ.get("REDIS_PORT", "6379")

RABBIT_MQ_HOST = os.environ.get("RABBIT_MQ_HOST", "localhost")
RABBIT_MQ_PORT = os.environ.get("RABBIT_MQ_PORT", "5672")
RABBIT_MQ_LOGIN = os.environ.get("RABBIT_MQ_LOGIN", "admin")
RABBIT_MQ_PASSWORD = os.environ.get("RABBIT_MQ_PASSWORD", "admin")
class Config:
def __init__(
self,
redis_host: str = "localhost",
redis_port: int = 6379,
redis_prefix: str = "llm-api",
rabbit_host: str = "localhost",
rabbit_port: int = 5672,
rabbit_login: str = "admin",
rabbit_password: str = "admin",
queue_name: str = "llm-api-queue",
model_path: str = None,
token_len: int = None,
tensor_parallel_size: int = None,
gpu_memory_utilisation: float = None,
):
self.redis_host = redis_host
self.redis_port = redis_port
self.redis_prefix = redis_prefix
self.rabbit_host = rabbit_host
self.rabbit_port = rabbit_port
self.rabbit_login = rabbit_login
self.rabbit_password = rabbit_password
self.queue_name = queue_name
self.model_path = model_path,
self.token_len = token_len,
self.tensor_parallel_size = tensor_parallel_size,
self.gpu_memory_utilisation = gpu_memory_utilisation,

QUEUE_NAME = os.environ.get("QUEUE_NAME", "llm-api-queue")
MODEL_PATH = os.environ.get("MODEL_PATH")
TOKENS_LEN = int(os.environ.get("TOKENS_LEN"))
TENSOR_PARALLEL_SIZE = int(os.environ.get("TENSOR_PARALLEL_SIZE"))
GPU_MEMORY_UTILISATION = float(os.environ.get("GPU_MEMORY_UTILISATION"))
@classmethod
def read_from_env(cls) -> 'Config':
return Config(
os.environ.get("REDIS_HOST", "localhost"),
os.environ.get("REDIS_PORT", "6379"),
os.environ.get("REDIS_PREFIX", "llm-api"),
os.environ.get("RABBIT_MQ_HOST", "localhost"),
os.environ.get("RABBIT_MQ_PORT", "5672"),
os.environ.get("RABBIT_MQ_LOGIN", "admin"),
os.environ.get("RABBIT_MQ_PASSWORD", "admin"),
os.environ.get("QUEUE_NAME", "llm-api-queue"),
os.environ.get("MODEL_PATH"),
int(os.environ.get("TOKENS_LEN")),
int(os.environ.get("TENSOR_PARALLEL_SIZE")),
float(os.environ.get("GPU_MEMORY_UTILISATION")),
)

@classmethod
def read_from_env_file(cls, path: str) -> 'Config':
with open(path) as file:
lines = file.readlines()
env_vars = {}
for line in lines:
key, value = line.split("=")
env_vars[key] = value
return Config(
env_vars.get("REDIS_HOST", "localhost"),
int(env_vars.get("REDIS_PORT", "6379")),
env_vars.get("REDIS_PREFIX", "llm-api"),
env_vars.get("RABBIT_MQ_HOST", "localhost"),
int(env_vars.get("RABBIT_MQ_PORT", "5672")),
env_vars.get("RABBIT_MQ_LOGIN", "admin"),
env_vars.get("RABBIT_MQ_PASSWORD", "admin"),
env_vars.get("QUEUE_NAME", "llm-api-queue"),
env_vars.get("MODEL_PATH"),
int(env_vars.get("TOKENS_LEN")),
int(env_vars.get("TENSOR_PARALLEL_SIZE")),
float(env_vars.get("GPU_MEMORY_UTILISATION")),
)
16 changes: 3 additions & 13 deletions protollm_tools/llm-worker/protollm_worker/main.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
from protollm_worker.config import Config
from protollm_worker.models.vllm_models import VllMModel
from protollm_worker.services.broker import LLMWrap

if __name__ == "__main__":
    # Build the worker configuration once from environment variables.
    config = Config.read_from_env()
    # BUG FIX: this line still referenced the removed module-level
    # MODEL_PATH constant, which would raise NameError after the move to
    # Config; read the path from the config object instead.
    llm_model = VllMModel(model_path=config.model_path)
    llm_wrap = LLMWrap(llm_model=llm_model,
                       config=config)
    llm_wrap.start_connection()
66 changes: 10 additions & 56 deletions protollm_tools/llm-worker/protollm_worker/services/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import pika
from protollm_sdk.models.job_context_models import PromptModel, ChatCompletionModel, PromptTransactionModel, \
PromptWrapper, ChatCompletionTransactionModel
from protollm_sdk.object_interface import RabbitMQWrapper
from protollm_sdk.object_interface.redis_wrapper import RedisWrapper

from protollm_worker.config import Config
from protollm_worker.models.base import BaseLLM

logging.basicConfig(level=logging.INFO)
Expand All @@ -22,83 +24,35 @@ class LLMWrap:

def __init__(self,
             llm_model: BaseLLM,
             config: Config):
    """
    Initialize the LLMWrap class with the necessary configurations.

    :param llm_model: The language model used to process prompts.
    :type llm_model: BaseLLM
    :param config: Connection settings for Redis and RabbitMQ
        (hosts, ports, credentials, queue name and result key prefix).
    :type config: Config
    """
    self.llm = llm_model
    logger.info('Loaded model')

    # Redis stores generation results under keys prefixed with redis_prefix.
    self.redis_bd = RedisWrapper(config.redis_host, config.redis_port)
    self.redis_prefix = config.redis_prefix
    # RabbitMQ delivers the incoming prompt / chat-completion jobs.
    self.rabbitMQ = RabbitMQWrapper(
        config.rabbit_host,
        config.rabbit_port,
        config.rabbit_login,
        config.rabbit_password,
    )
    logger.info('Connected to Redis')

    # Maps the message type tag to the pydantic model used to parse it.
    self.models = {
        'single_generate': PromptModel,
        'chat_completion': ChatCompletionModel,
    }

    self.queue_name = config.queue_name

def start_connection(self):
    """
    Start consuming messages from the configured RabbitMQ queue.

    Connection management is delegated to the SDK's RabbitMQWrapper
    instead of opening a pika BlockingConnection by hand.
    """
    # Log before handing control to the consumer. NOTE(review): assumes
    # consume_messages blocks in a consume loop (as the pika code it
    # replaced did) — if so, a log call placed after it would never run;
    # confirm against the protollm_sdk RabbitMQWrapper implementation.
    logger.info('Started consuming messages')
    self.rabbitMQ.consume_messages(self.queue_name, self._callback)

def _dump_from_body(self, message_body) -> PromptModel | ChatCompletionModel:
Expand Down
2 changes: 1 addition & 1 deletion protollm_tools/llm-worker/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ python = "^3.10"
redis = "^5.0.5"
pika = "^1.3.2"
pydantic = "^2.7.4"
protollm_sdk = "^1.0.0"
protollm_sdk = "^1.1.0"
vllm = "^0.6.4.post1"

[tool.poetry.llama-cpp]
Expand Down
2 changes: 1 addition & 1 deletion protollm_tools/sdk/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "protollm-sdk"
version = "1.1.0"
version = "1.1.1"
description = ""
authors = ["aimclub"]
readme = "README.md"
Expand Down

0 comments on commit 694532c

Please sign in to comment.