Switch split task to token-based splitting #283

Open · wants to merge 13 commits into base: main
61 changes: 19 additions & 42 deletions client/src/nv_ingest_client/primitives/tasks/split.py
@@ -8,29 +8,18 @@

 import logging
 from typing import Dict
-from typing import Literal
-from typing import Optional

-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel

 from .task_base import Task

 logger = logging.getLogger(__name__)


 class SplitTaskSchema(BaseModel):
-    split_by: Optional[str] = "sentence"
-    split_length: Optional[int] = 10
-    split_overlap: Optional[int] = 0
-    max_character_length: Optional[int] = 1024
-    sentence_window_size: Optional[int] = 0
-
-    @field_validator("split_by")
-    def split_by_must_be_valid(cls, v):
-        valid_criteria = ["page", "size", "word", "sentence"]
-        if v not in valid_criteria:
-            raise ValueError(f"split_by must be one of {valid_criteria}")
-        return v
+    tokenizer: str = "intfloat/e5-large-unsupervised"
+    chunk_size: int = 300
+    chunk_overlap: int = 0

     class Config:
         extra = "forbid"
@@ -41,37 +30,29 @@ class SplitTask(Task):
     Object for document splitting task
     """

-    _TypeSplitBy = Literal["word", "sentence", "passage"]
-
     def __init__(
         self,
-        split_by: _TypeSplitBy = None,
-        split_length: int = None,
-        split_overlap: int = None,
-        max_character_length: int = None,
-        sentence_window_size: int = None,
+        tokenizer: str = "intfloat/e5-large-unsupervised",
+        chunk_size: int = 300,
+        chunk_overlap: int = 0,
     ) -> None:
         """
         Setup Split Task Config
        """
         super().__init__()
-        self._split_by = split_by
-        self._split_length = split_length
-        self._split_overlap = split_overlap
-        self._max_character_length = max_character_length
-        self._sentence_window_size = sentence_window_size
+        self._tokenizer = tokenizer
+        self._chunk_size = chunk_size
+        self._chunk_overlap = chunk_overlap

     def __str__(self) -> str:
         """
         Returns a string with the object's config and run time state
         """
         info = ""
         info += "Split Task:\n"
-        info += f" split_by: {self._split_by}\n"
-        info += f" split_length: {self._split_length}\n"
-        info += f" split_overlap: {self._split_overlap}\n"
-        info += f" split_max_character_length: {self._max_character_length}\n"
-        info += f" split_sentence_window_size: {self._sentence_window_size}\n"
+        info += f" tokenizer: {self._tokenizer}\n"
+        info += f" chunk_size: {self._chunk_size}\n"
+        info += f" chunk_overlap: {self._chunk_overlap}\n"
         return info

     def to_dict(self) -> Dict:
@@ -80,15 +61,11 @@ def to_dict(self) -> Dict:
         """
         split_params = {}

-        if self._split_by is not None:
-            split_params["split_by"] = self._split_by
-        if self._split_length is not None:
-            split_params["split_length"] = self._split_length
-        if self._split_overlap is not None:
-            split_params["split_overlap"] = self._split_overlap
-        if self._max_character_length is not None:
-            split_params["max_character_length"] = self._max_character_length
-        if self._sentence_window_size is not None:
-            split_params["sentence_window_size"] = self._sentence_window_size
+        if self._tokenizer is not None:
+            split_params["tokenizer"] = self._tokenizer
+        if self._chunk_size is not None:
+            split_params["chunk_size"] = self._chunk_size
+        if self._chunk_overlap is not None:
+            split_params["chunk_overlap"] = self._chunk_overlap

         return {"type": "split", "task_properties": split_params}
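For reference, a minimal usage sketch of the reworked client task. The import path is taken from the file header above, and the note that chunk_size and chunk_overlap count tokens is inferred from the PR title rather than stated in the diff.

from nv_ingest_client.primitives.tasks.split import SplitTask

# Build a split task with the new token-based parameters
# (values shown are the defaults from SplitTaskSchema above).
task = SplitTask(
    tokenizer="intfloat/e5-large-unsupervised",  # Hugging Face tokenizer name
    chunk_size=300,    # assumed: tokens per chunk, per the PR title
    chunk_overlap=0,   # assumed: overlapping tokens between adjacent chunks
)

print(task)  # "Split Task:" summary rendered by __str__

# to_dict() serializes the task to the form shown in the diff:
# {"type": "split", "task_properties": {"tokenizer": "intfloat/e5-large-unsupervised",
#  "chunk_size": 300, "chunk_overlap": 0}}
print(task.to_dict())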
4 changes: 2 additions & 2 deletions src/nv_ingest/modules/transforms/__init__.py
@@ -3,6 +3,6 @@
 # SPDX-License-Identifier: Apache-2.0

 from .associate_nearby_text import AssociateNearbyTextLoaderFactory
-from .nemo_doc_splitter import NemoDocSplitterLoaderFactory
+from .text_splitter import TextSplitterLoaderFactory

-__all__ = ["NemoDocSplitterLoaderFactory", "AssociateNearbyTextLoaderFactory"]
+__all__ = ["TextSplitterLoaderFactory", "AssociateNearbyTextLoaderFactory"]
223 changes: 0 additions & 223 deletions src/nv_ingest/modules/transforms/nemo_doc_splitter.py

This file was deleted.
