diff --git a/poetry.lock b/poetry.lock
index eb0384f..3ec1d72 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,5 +1,16 @@
 # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
+[[package]]
+name = "aiofiles"
+version = "24.1.0"
+description = "File support for asyncio."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"},
+    {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"},
+]
+
 [[package]]
 name = "aiohttp"
 version = "3.9.5"
@@ -4194,4 +4205,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "678a8e8edc4926462efbf5699598d6e14a4c281dbb5e37285257600a1871145a"
+content-hash = "33d9f7e57759cf256a1c5a4149553bf72bffe6a8266fec561cc1d49a134721c5"
diff --git a/pyproject.toml b/pyproject.toml
index 4c4102f..2d51c81 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,6 +22,7 @@ pyyaml = "^6.0.1"
 cerberus = "^1.3.5"
 pydantic = "^2.7.3"
 setuptools = "^70.0.0"
+aiofiles = "^24.1.0"
 
 [tool.poetry.group.dev.dependencies]
 black = "^24.4.2"
diff --git a/ragulate/cli_commands/download.py b/ragulate/cli_commands/download.py
index 044cbbe..2944a2a 100644
--- a/ragulate/cli_commands/download.py
+++ b/ragulate/cli_commands/download.py
@@ -1,4 +1,4 @@
-from ragulate.datasets import LlamaDataset
+from ragulate.datasets import get_dataset
 
 
 def setup_download(subparsers):
@@ -22,7 +22,5 @@ def setup_download(subparsers):
 
 
 def call_download(dataset_name: str, kind: str, **kwargs):
-    if not kind == "llama":
-        raise ("Currently only Llama Datasets are supported. Set param `-k llama`")
-    llama = LlamaDataset(dataset_name=dataset_name)
-    llama.download_dataset()
+    dataset = get_dataset(name=dataset_name, kind=kind)
+    dataset.download_dataset()
diff --git a/ragulate/cli_commands/query.py b/ragulate/cli_commands/query.py
index 96f2148..cf7e0d2 100644
--- a/ragulate/cli_commands/query.py
+++ b/ragulate/cli_commands/query.py
@@ -55,6 +55,15 @@ def setup_query(subparsers):
         help=("The name of a dataset to query", "This can be passed multiple times."),
         action="append",
     )
+    query_parser.add_argument(
+        "--subset",
+        type=str,
+        help=(
+            "The subset of the dataset to query",
+            "Only valid when a single dataset is passed.",
+        ),
+        action="append",
+    )
     query_parser.set_defaults(func=lambda args: call_query(**vars(args)))
@@ -64,10 +73,20 @@ def call_query(
     var_name: List[str],
     var_value: List[str],
     dataset: List[str],
+    subset: List[str],
     **kwargs,
 ):
+    datasets = [find_dataset(name=name) for name in dataset]
+
+    # argparse's action="append" leaves `subset` as None when the flag is
+    # never passed, so truth-test it instead of calling len() on it
+    if subset:
+        if len(datasets) > 1:
+            raise ValueError(
+                "Can only set the `subset` param when there is exactly one dataset"
+            )
+        else:
+            datasets[0].subsets = subset
+
     ingredients = convert_vars_to_ingredients(
         var_names=var_name, var_values=var_value
    )
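Note: a hedged sketch of how the new flag composes with `--dataset` (the `ragulate` entry-point name, the dataset names, and the elided required flags are assumptions, not shown in this diff):

    # one --dataset: its subsets are set from the --subset values
    ragulate query ... --dataset hotpotqa --subset music --subset sports

    # multiple --dataset flags plus --subset: call_query raises ValueError
    ragulate query ... --dataset hotpotqa --dataset squad --subset music

Both flags use action="append", so each may be repeated; `call_query` only applies `subset` when exactly one dataset was named.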
diff --git a/ragulate/datasets/__init__.py b/ragulate/datasets/__init__.py
index 9648291..42abadb 100644
--- a/ragulate/datasets/__init__.py
+++ b/ragulate/datasets/__init__.py
@@ -1,9 +1,11 @@
 from .base_dataset import BaseDataset
+from .crag_dataset import CragDataset
 from .llama_dataset import LlamaDataset
 from .utils import find_dataset, get_dataset
 
 __all__ = [
     "BaseDataset",
+    "CragDataset",
     "LlamaDataset",
     "find_dataset",
     "get_dataset",
diff --git a/ragulate/datasets/base_dataset.py b/ragulate/datasets/base_dataset.py
index 3c22dd6..aba5694 100644
--- a/ragulate/datasets/base_dataset.py
+++ b/ragulate/datasets/base_dataset.py
@@ -1,13 +1,20 @@
+import bz2
+import tempfile
 from abc import ABC, abstractmethod
-from os import path
+from os import makedirs, path
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
+import aiofiles
+import aiohttp
+from tqdm.asyncio import tqdm
+
 
 class BaseDataset(ABC):
     root_storage_path: str
     name: str
+    _subsets: List[str] = []
 
     def __init__(
         self, dataset_name: str, root_storage_path: Optional[str] = "datasets"
@@ -27,6 +34,14 @@ def list_files_at_path(self, path: str) -> List[str]:
             if f.is_file() and not f.name.startswith(".")
         ]
 
+    @property
+    def subsets(self) -> List[str]:
+        return self._subsets
+
+    @subsets.setter
+    def subsets(self, value: List[str]):
+        self._subsets = value
+
     @abstractmethod
     def sub_storage_path(self) -> str:
         """the sub-path to store the dataset in"""
@@ -42,3 +57,36 @@ def get_source_file_paths(self) -> List[str]:
     @abstractmethod
     def get_queries_and_golden_set(self) -> Tuple[List[str], List[Dict[str, str]]]:
         """gets a list of queries and golden_truth answers for a dataset"""
+
+    async def _download_file(self, session, url, temp_file_path):
+        async with session.get(url) as response:
+            file_size = int(response.headers.get("Content-Length", 0))
+            chunk_size = 1024
+            with tqdm(total=file_size, unit="B", unit_scale=True, desc=f"Downloading {url.split('/')[-1]}") as progress_bar:
+                async with aiofiles.open(temp_file_path, "wb") as temp_file:
+                    async for chunk in response.content.iter_chunked(chunk_size):
+                        await temp_file.write(chunk)
+                        progress_bar.update(len(chunk))
+
+    async def _decompress_file(self, temp_file_path, output_file_path):
+        makedirs(path.dirname(output_file_path), exist_ok=True)
+        with open(temp_file_path, "rb") as temp_file:
+            decompressed_size = 0
+            with bz2.BZ2File(temp_file, "rb") as bz2_file:
+                async with aiofiles.open(output_file_path, "wb") as output_file:
+                    with tqdm(unit="B", unit_scale=True, desc=f"Decompressing {output_file_path}") as progress_bar:
+                        while True:
+                            chunk = bz2_file.read(1024)
+                            if not chunk:
+                                break
+                            await output_file.write(chunk)
+                            decompressed_size += len(chunk)
+                            progress_bar.update(len(chunk))
+
+    async def _download_and_decompress(self, url, output_file_path):
+        async with aiohttp.ClientSession() as session:
+            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+                temp_file_path = temp_file.name
+
+            await self._download_file(session, url, temp_file_path)
+            await self._decompress_file(temp_file_path, output_file_path)
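For orientation, a minimal sketch (not part of this diff) of how a concrete dataset class might drive the new `_download_and_decompress` helper; the dataset name, URL, and output path are illustrative placeholders:

    import asyncio

    from ragulate.datasets import CragDataset

    dataset = CragDataset(dataset_name="crag_task_1")  # hypothetical name

    # streams the archive to a temp file (with a tqdm progress bar),
    # then bz2-decompresses it into the dataset's storage path
    asyncio.run(
        dataset._download_and_decompress(
            url="https://example.com/crag/music.jsonl.bz2",  # placeholder URL
            output_file_path="datasets/crag/crag_task_1/music.jsonl",
        )
    )

Because `_download_file` and `_decompress_file` are awaited one after the other, decompression starts only once the download has finished.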
diff --git a/ragulate/datasets/utils.py b/ragulate/datasets/utils.py
index db7b18a..60766d0 100644
--- a/ragulate/datasets/utils.py
+++ b/ragulate/datasets/utils.py
@@ -1,17 +1,32 @@
+import os
 from typing import List
 
 from .base_dataset import BaseDataset
+from .crag_dataset import CragDataset
 from .llama_dataset import LlamaDataset
 
 
-# TODO: implement this when adding additional dataset kinds
 def find_dataset(name:str) -> BaseDataset:
     """ searches for a downloaded dataset with this name.
     if found, returns it."""
+    root_path = "datasets"
+    name = name.lower()
+    for kind in os.listdir(root_path):
+        kind_path = os.path.join(root_path, kind)
+        if os.path.isdir(kind_path):
+            for dataset in os.listdir(kind_path):
+                dataset_path = os.path.join(kind_path, dataset)
+                if os.path.isdir(dataset_path):
+                    if dataset.lower() == name:
+                        return get_dataset(name, kind)
+
     return get_dataset(name, "llama")
 
 
 def get_dataset(name:str, kind:str) -> BaseDataset:
+    kind = kind.lower()
     if kind == "llama":
         return LlamaDataset(dataset_name=name)
+    elif kind == "crag":
+        return CragDataset(dataset_name=name)
 
-    raise NotImplementedError("only llama datasets are currently supported")
+    raise NotImplementedError("only llama and crag datasets are currently supported")
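To see how resolution works after this change, a small hedged example (the dataset name is illustrative):

    from ragulate.datasets import find_dataset, get_dataset

    # explicit kind: dispatch on the lower-cased kind string
    ds = get_dataset(name="my_dataset", kind="crag")

    # name only: scan datasets/<kind>/<name> for an already-downloaded
    # match first, then fall back to treating the name as a llama dataset
    ds = find_dataset(name="my_dataset")

One caveat worth knowing: `find_dataset` calls `os.listdir("datasets")` unconditionally, so on a fresh checkout where nothing has been downloaded yet (no `datasets/` directory) it raises FileNotFoundError rather than falling through to the llama default.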