diff --git a/protollm_tools/llm-worker/poetry.lock b/protollm_tools/llm-worker/poetry.lock index c478355..c3341fa 100644 --- a/protollm_tools/llm-worker/poetry.lock +++ b/protollm_tools/llm-worker/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -2846,13 +2846,13 @@ files = [ [[package]] name = "protollm-sdk" -version = "1.0.0" +version = "1.1.0" description = "" optional = false python-versions = "<4.0,>=3.10" files = [ - {file = "protollm_sdk-1.0.0-py3-none-any.whl", hash = "sha256:bd53331811e788c606551a7c19d2c59496612db8ac237f2c28a5bec6cf373c89"}, - {file = "protollm_sdk-1.0.0.tar.gz", hash = "sha256:7f34ab288115a33d44047d689703af731abd66af11846b5233c7083b3df8a8e7"}, + {file = "protollm_sdk-1.1.0-py3-none-any.whl", hash = "sha256:2f49516d0229a85fa0abf7d2bead2ed96a565686eb87b9559468eeb3198db7c7"}, + {file = "protollm_sdk-1.1.0.tar.gz", hash = "sha256:165d7d1270dadc7eacbf6d973f8410cd15fe2c7452cf6fb5f784cd5abc1f4c0e"}, ] [package.dependencies] @@ -5060,4 +5060,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "bef6d37fd1049fed6ba443ba2fa5abbcc3fff1499232944f872b272410ac659c" +content-hash = "1f1fe7c6d722ba2011a2d043068fd4c2aa8cdf5c7d7808e6f1aed30d75805300" diff --git a/protollm_tools/llm-worker/protollm_worker/config.py b/protollm_tools/llm-worker/protollm_worker/config.py index 2b6ef2a..513399c 100644 --- a/protollm_tools/llm-worker/protollm_worker/config.py +++ b/protollm_tools/llm-worker/protollm_worker/config.py @@ -1,16 +1,71 @@ import os -REDIS_PREFIX = os.environ.get("REDIS_PREFIX", "llm-api") -REDIS_HOST = os.environ.get("REDIS_HOST", "localhost") -REDIS_PORT = os.environ.get("REDIS_PORT", "6379") -RABBIT_MQ_HOST = os.environ.get("RABBIT_MQ_HOST", "localhost") -RABBIT_MQ_PORT = 
os.environ.get("RABBIT_MQ_PORT", "5672") -RABBIT_MQ_LOGIN = os.environ.get("RABBIT_MQ_LOGIN", "admin") -RABBIT_MQ_PASSWORD = os.environ.get("RABBIT_MQ_PASSWORD", "admin") +class Config: + def __init__( + self, + redis_host: str = "localhost", + redis_port: int = 6379, + redis_prefix: str = "llm-api", + rabbit_host: str = "localhost", + rabbit_port: int = 5672, + rabbit_login: str = "admin", + rabbit_password: str = "admin", + queue_name: str = "llm-api-queue", + model_path: str = None, + token_len: int = None, + tensor_parallel_size: int = None, + gpu_memory_utilisation: float = None, + ): + self.redis_host = redis_host + self.redis_port = redis_port + self.redis_prefix = redis_prefix + self.rabbit_host = rabbit_host + self.rabbit_port = rabbit_port + self.rabbit_login = rabbit_login + self.rabbit_password = rabbit_password + self.queue_name = queue_name + self.model_path = model_path + self.token_len = token_len + self.tensor_parallel_size = tensor_parallel_size + self.gpu_memory_utilisation = gpu_memory_utilisation -QUEUE_NAME = os.environ.get("QUEUE_NAME", "llm-api-queue") -MODEL_PATH = os.environ.get("MODEL_PATH") -TOKENS_LEN = int(os.environ.get("TOKENS_LEN")) -TENSOR_PARALLEL_SIZE = int(os.environ.get("TENSOR_PARALLEL_SIZE")) -GPU_MEMORY_UTILISATION = float(os.environ.get("GPU_MEMORY_UTILISATION")) + @classmethod + def read_from_env(cls) -> 'Config': + return Config( + os.environ.get("REDIS_HOST", "localhost"), + int(os.environ.get("REDIS_PORT", "6379")), + os.environ.get("REDIS_PREFIX", "llm-api"), + os.environ.get("RABBIT_MQ_HOST", "localhost"), + int(os.environ.get("RABBIT_MQ_PORT", "5672")), + os.environ.get("RABBIT_MQ_LOGIN", "admin"), + os.environ.get("RABBIT_MQ_PASSWORD", "admin"), + os.environ.get("QUEUE_NAME", "llm-api-queue"), + os.environ.get("MODEL_PATH"), + int(os.environ.get("TOKENS_LEN")), + int(os.environ.get("TENSOR_PARALLEL_SIZE")), + float(os.environ.get("GPU_MEMORY_UTILISATION")), + ) + + @classmethod + def read_from_env_file(cls, path: str) -> 
'Config': + with open(path) as file: + lines = file.readlines() + env_vars = {} + for line in lines: + key, value = line.strip().split("=", 1) + env_vars[key] = value + return Config( + env_vars.get("REDIS_HOST", "localhost"), + int(env_vars.get("REDIS_PORT", "6379")), + env_vars.get("REDIS_PREFIX", "llm-api"), + env_vars.get("RABBIT_MQ_HOST", "localhost"), + int(env_vars.get("RABBIT_MQ_PORT", "5672")), + env_vars.get("RABBIT_MQ_LOGIN", "admin"), + env_vars.get("RABBIT_MQ_PASSWORD", "admin"), + env_vars.get("QUEUE_NAME", "llm-api-queue"), + env_vars.get("MODEL_PATH"), + int(env_vars.get("TOKENS_LEN")), + int(env_vars.get("TENSOR_PARALLEL_SIZE")), + float(env_vars.get("GPU_MEMORY_UTILISATION")), + ) \ No newline at end of file diff --git a/protollm_tools/llm-worker/protollm_worker/main.py b/protollm_tools/llm-worker/protollm_worker/main.py index 9c6e6bc..993ecff 100644 --- a/protollm_tools/llm-worker/protollm_worker/main.py +++ b/protollm_tools/llm-worker/protollm_worker/main.py @@ -1,21 +1,10 @@ -from protollm_worker.config import MODEL_PATH, REDIS_HOST, REDIS_PORT, QUEUE_NAME from protollm_worker.models.vllm_models import VllMModel from protollm_worker.services.broker import LLMWrap -from protollm_worker.config import ( - RABBIT_MQ_HOST, RABBIT_MQ_PORT, - RABBIT_MQ_PASSWORD, RABBIT_MQ_LOGIN, - REDIS_PREFIX -) +from protollm_worker.config import Config if __name__ == "__main__": + config = Config.read_from_env() - llm_model = VllMModel(model_path=MODEL_PATH) + llm_model = VllMModel(model_path=config.model_path) llm_wrap = LLMWrap(llm_model=llm_model, - redis_host= REDIS_HOST, - redis_port= REDIS_PORT, - queue_name= QUEUE_NAME, - rabbit_host= RABBIT_MQ_HOST, - rabbit_port= RABBIT_MQ_PORT, - rabbit_login= RABBIT_MQ_LOGIN, - rabbit_password= RABBIT_MQ_PASSWORD, - redis_prefix= REDIS_PREFIX) + config= config) llm_wrap.start_connection() diff --git a/protollm_tools/llm-worker/protollm_worker/services/broker.py b/protollm_tools/llm-worker/protollm_worker/services/broker.py index 20db797..b5b3535 100644 --- 
a/protollm_tools/llm-worker/protollm_worker/services/broker.py +++ b/protollm_tools/llm-worker/protollm_worker/services/broker.py @@ -4,8 +4,10 @@ import pika from protollm_sdk.models.job_context_models import PromptModel, ChatCompletionModel, PromptTransactionModel, \ PromptWrapper, ChatCompletionTransactionModel +from protollm_sdk.object_interface import RabbitMQWrapper from protollm_sdk.object_interface.redis_wrapper import RedisWrapper +from protollm_worker.config import Config from protollm_worker.models.base import BaseLLM logging.basicConfig(level=logging.INFO) @@ -22,41 +24,21 @@ class LLMWrap: def __init__(self, llm_model: BaseLLM, - redis_host: str, - redis_port: str, - queue_name: str, - rabbit_host: str, - rabbit_port: str, - rabbit_login: str, - rabbit_password: str, - redis_prefix: str): + config: Config): """ Initialize the LLMWrap class with the necessary configurations. :param llm_model: The language model to use for processing prompts. :type llm_model: BaseLLM - :param redis_host: Hostname for the Redis server. - :type redis_host: str - :param redis_port: Port for the Redis server. - :type redis_port: str - :param queue_name: Name of the RabbitMQ queue to consume messages from. - :type queue_name: str - :param rabbit_host: Hostname for the RabbitMQ server. - :type rabbit_host: str - :param rabbit_port: Port for the RabbitMQ server. - :type rabbit_port: str - :param rabbit_login: Login for RabbitMQ authentication. - :type rabbit_login: str - :param rabbit_password: Password for RabbitMQ authentication. - :type rabbit_password: str - :param redis_prefix: Prefix for Redis keys to store results. - :type redis_prefix: str + :param config: Set for setting Redis and RabbitMQ. 
+ :type config: Config """ self.llm = llm_model logger.info('Loaded model') - self.redis_bd = RedisWrapper(redis_host, redis_port) - self.redis_prefix = redis_prefix + self.redis_bd = RedisWrapper(config.redis_host, config.redis_port) + self.rabbitMQ = RabbitMQWrapper(config.rabbit_host, config.rabbit_port, config.rabbit_login, config.rabbit_password) + self.redis_prefix = config.redis_prefix logger.info('Connected to Redis') self.models = { @@ -64,41 +46,13 @@ def __init__(self, 'chat_completion': ChatCompletionModel, } - self.queue_name = queue_name - self.rabbit_host = rabbit_host - self.rabbit_port = rabbit_port - self.rabbit_login = rabbit_login - self.rabbit_password = rabbit_password + self.queue_name = config.queue_name def start_connection(self): """ Establish a connection to the RabbitMQ broker and start consuming messages from the specified queue. """ - connection = pika.BlockingConnection( - pika.ConnectionParameters( - host=self.rabbit_host, - port=self.rabbit_port, - virtual_host='/', - credentials=pika.PlainCredentials( - username=self.rabbit_login, - password=self.rabbit_password - ) - ) - ) - - channel = connection.channel() - logger.info('Connected to the broker') - - channel.queue_declare(queue=self.queue_name) - logger.info('Queue has been declared') - - channel.basic_consume( - on_message_callback=self._callback, - queue=self.queue_name, - auto_ack=True - ) - - channel.start_consuming() + self.rabbitMQ.consume_messages(self.queue_name, self._callback) logger.info('Started consuming messages') def _dump_from_body(self, message_body) -> PromptModel | ChatCompletionModel: diff --git a/protollm_tools/llm-worker/pyproject.toml b/protollm_tools/llm-worker/pyproject.toml index 3e9bdaf..b46de09 100644 --- a/protollm_tools/llm-worker/pyproject.toml +++ b/protollm_tools/llm-worker/pyproject.toml @@ -10,7 +10,7 @@ python = "^3.10" redis = "^5.0.5" pika = "^1.3.2" pydantic = "^2.7.4" -protollm_sdk = "^1.0.0" +protollm_sdk = "^1.1.0" vllm = "^0.6.4.post1" 
[toll.poetry.llama-cpp] diff --git a/protollm_tools/sdk/pyproject.toml b/protollm_tools/sdk/pyproject.toml index e9bdc7c..f8c075e 100644 --- a/protollm_tools/sdk/pyproject.toml +++ b/protollm_tools/sdk/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "protollm-sdk" -version = "1.1.0" +version = "1.1.1" description = "" authors = ["aimclub"] readme = "README.md"