Skip to content

Commit

Permalink
Adding mode knob in config under classifier (#538)
Browse files Browse the repository at this point in the history
* Adding mode knob in config under classifier

* Updating classification mode request

* Incorporated code review requests

* Added classifier_mode as req param in /loader/doc API

* Incorporated review requests, lint fixes

* Added default value
  • Loading branch information
dristysrivastava authored Sep 11, 2024
1 parent 6a6265e commit 6ebdf6e
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 59 deletions.
3 changes: 2 additions & 1 deletion docs/gh_pages/docs/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ Notes:

### Classifier

- `mode`: Specifies the mode for the classify API. Possible values are `all`, `entity`, or `topic`. Default value is `all`. When the value is `all`, both entities and topics are classified; when it is `entity`, only entities are classified; when it is `topic`, only topics are classified. It is used for classification in the `/classify` and `/loader/doc` APIs.
- `anonymizeSnippets`: Flag to anonymize snippets in the report. Possible values are `True` and `False`. When the value is `True`, snippets in reports are shown anonymized; when it is `False`, snippets are shown without anonymization.

### Storage
This is a beta feature introduced in 0.1.18.
- `type`: Specifies storage type to store states of the GenAI applications. Possible values are `file` or `db`. Default value is `file`. By default SQLite database is used when we set it as `db`.
- `type`: Specifies storage type to store states of the GenAI applications. Possible values are `file` or `db`. Default value is `file`. By default, SQLite database is used when we set it as `db`.
- Setting `type` to `file` is deprecated; use `db` instead. `file` will no longer be supported from the 0.1.19 release onwards.

### Default Configuration
Expand Down
13 changes: 12 additions & 1 deletion pebblo/app/api/req_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from typing import List, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, Field

from pebblo.app.enums.common import ClassificationMode


class Runtime(BaseModel):
Expand Down Expand Up @@ -92,3 +94,12 @@ class ReqPrompt(BaseModel):

class ReqPromptGov(BaseModel):
prompt: str


class ReqClassifier(BaseModel):
    """Request model for the "/classify" API.

    Carries the text to classify together with the options that control
    which classifiers run and whether the anonymized text is returned.
    """

    # Text to run through the entity and/or topic classifiers.
    data: str
    # Which classification to perform; defaults to ClassificationMode.ALL.
    mode: Optional[ClassificationMode] = Field(default=ClassificationMode.ALL)
    # When true, the response data is replaced with the anonymized document
    # produced by entity classification; defaults to False.
    anonymize: Optional[bool] = Field(default=False)

    class Config:
        # Reject request bodies that contain fields not declared above.
        extra = "forbid"
7 changes: 5 additions & 2 deletions pebblo/app/api/v1.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import APIRouter

from pebblo.app.api.req_models import ReqClassifier
from pebblo.app.config.config import var_server_config_dict
from pebblo.app.service.classification import Classification

Expand All @@ -15,7 +16,9 @@ def __init__(self, prefix: str):
self.router = APIRouter(prefix=prefix)

@staticmethod
def classify_data(data: dict):
cls_obj = Classification(data)
def classify_data(data: ReqClassifier):
# "/classify" API entrypoint
# Execute entity/topic classification
cls_obj = Classification(data.model_dump())
response = cls_obj.process_request()
return response
7 changes: 5 additions & 2 deletions pebblo/app/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic_settings import BaseSettings

from pebblo.app.config.config_validation import validate_config, validate_input
from pebblo.app.enums.common import DBStorageTypes, StorageTypes
from pebblo.app.enums.common import ClassificationMode, DBStorageTypes, StorageTypes

# Default config value
dir_path = pathlib.Path().absolute()
Expand Down Expand Up @@ -44,6 +44,7 @@ class LoggingConfig(BaseSettings):


class ClassifierConfig(BaseSettings):
    """Settings for the `classifier` section of the server configuration."""

    # Classification mode as a plain string; defaults to
    # ClassificationMode.ALL.value.
    mode: str = Field(default=ClassificationMode.ALL.value)
    # Flag to anonymize snippets in reports.
    # NOTE(review): the field default here is True, while load_config builds
    # its default ClassifierConfig with anonymizeSnippets=False -- confirm
    # which default is intended.
    anonymizeSnippets: bool = Field(default=True)


Expand Down Expand Up @@ -77,7 +78,9 @@ def load_config(path: Optional[str]) -> Tuple[dict, Config]:
format="pdf", renderer="xhtml2pdf", cacheDir="~/.pebblo"
),
logging=LoggingConfig(),
classifier=ClassifierConfig(anonymizeSnippets=False),
classifier=ClassifierConfig(
mode=ClassificationMode.ALL.value, anonymizeSnippets=False
),
storage=StorageConfig(type="file", db=None),
# For now, the default storage type is FILE; in the next release, DB will become the default storage type.
)
Expand Down
1 change: 1 addition & 0 deletions pebblo/app/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ reports:
renderer: xhtml2pdf
cacheDir: ~/.pebblo
classifier:
mode: all
anonymizeSnippets: False
storage:
type: file
10 changes: 9 additions & 1 deletion pebblo/app/config/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from abc import ABC, abstractmethod

from pebblo.app.enums.common import DBStorageTypes, StorageTypes
from pebblo.app.enums.common import ClassificationMode, DBStorageTypes, StorageTypes


class ConfigValidator(ABC):
Expand Down Expand Up @@ -136,7 +136,15 @@ def validate_input(input_dict):

class ClassifierConfig(ConfigValidator):
def validate(self):
mode = self.config.get("mode")
anonymize_snippets = self.config.get("anonymizeSnippets")
valid_classification_modes = [
classification_mode.value for classification_mode in ClassificationMode
]
if mode not in valid_classification_modes:
self.errors.append(
f"Error: Unsupported classifier mode '{mode}' specified in the configuration. Valid values are {valid_classification_modes}"
)
if not isinstance(anonymize_snippets, bool):
self.errors.append(
f"Error: Invalid anonymizeSnippets '{anonymize_snippets}'. anonymizeSnippets must be a boolean."
Expand Down
6 changes: 6 additions & 0 deletions pebblo/app/enums/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,9 @@ class StorageTypes(Enum):
class DBStorageTypes(Enum):
    """Database backend identifiers.

    NOTE(review): presumably the backends selectable when the storage type
    is `db` -- confirm against the storage configuration code.
    """

    SQLITE = "sqlite"
    MONGODB = "mongodb"


class ClassificationMode(Enum):
    """Classification modes accepted by the classifier config and APIs."""

    # Classify both entities and topics.
    ALL = "all"
    # Classify entities only.
    ENTITY = "entity"
    # Classify topics only.
    TOPIC = "topic"
44 changes: 20 additions & 24 deletions pebblo/app/service/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,19 @@
"""

import traceback
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field, ValidationError
from pydantic import ValidationError

from pebblo.app.api.req_models import ReqClassifier
from pebblo.app.config.config import var_server_config_dict
from pebblo.app.enums.common import ClassificationMode
from pebblo.app.libs.responses import PebbloJsonResponse
from pebblo.app.models.models import AiDataModel
from pebblo.entity_classifier.entity_classifier import EntityClassifier
from pebblo.log import get_logger
from pebblo.topic_classifier.topic_classifier import TopicClassifier


class ClassificationMode(Enum):
ENTITY = "entity"
TOPIC = "topic"
ALL = "all"


class ReqClassifier(BaseModel):
data: str
mode: Optional[ClassificationMode] = Field(default=ClassificationMode.ALL)
anonymize: Optional[bool] = Field(default=False)

class Config:
extra = "forbid"
config_details = var_server_config_dict.get()


logger = get_logger(__name__)
Expand All @@ -40,10 +28,11 @@ class Classification:
Classification wrapper class for Entity and Semantic classification with anonymization
"""

def __init__(self, input: dict):
self.input = input
def __init__(self, data: dict):
self.input = data

def _get_classifier_response(self, req: ReqClassifier):
@staticmethod
def _get_classifier_response(req: ReqClassifier):
"""
Processes the input prompt through the entity classifier and anonymizer, and returns
the resulting information encapsulated in an AiDataModel object.
Expand All @@ -52,7 +41,7 @@ def _get_classifier_response(self, req: ReqClassifier):
AiDataModel: An object containing the anonymized document, entities, and their counts.
"""
doc_info = AiDataModel(
data=None,
data=req.data,
entities={},
entityCount=0,
entityDetails={},
Expand All @@ -62,7 +51,10 @@ def _get_classifier_response(self, req: ReqClassifier):
)
try:
# Process entity classification
if req.mode in [ClassificationMode.ENTITY, ClassificationMode.ALL]:
if req.mode in [
ClassificationMode.ENTITY,
ClassificationMode.ALL,
]:
(
entities,
entity_count,
Expand All @@ -75,10 +67,14 @@ def _get_classifier_response(self, req: ReqClassifier):
doc_info.entities = entities
doc_info.entityCount = entity_count
doc_info.entityDetails = entity_details
doc_info.data = anonymized_doc if req.anonymize else ""
if req.anonymize:
doc_info.data = anonymized_doc

# Process topic classification
if req.mode in [ClassificationMode.TOPIC, ClassificationMode.ALL]:
if req.mode in [
ClassificationMode.TOPIC,
ClassificationMode.ALL,
]:
topics, topic_count, topic_details = topic_classifier_obj.predict(
req.data
)
Expand Down
57 changes: 38 additions & 19 deletions pebblo/app/service/loader/loader_doc_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from datetime import datetime
from os import makedirs, path

from pebblo.app.config.config import var_server_config_dict
from pebblo.app.enums.common import ClassificationMode
from pebblo.app.enums.enums import ApplicationTypes, CacheDir, ClassifierConstants
from pebblo.app.libs.responses import PebbloJsonResponse
from pebblo.app.models.db_models import (
Expand All @@ -25,6 +27,7 @@
from pebblo.reports.reports import Reports
from pebblo.topic_classifier.topic_classifier import TopicClassifier

config_details = var_server_config_dict.get()
logger = get_logger(__name__)

# Init topic classifier
Expand All @@ -36,6 +39,7 @@ def __init__(self):
self.db = None
self.data = None
self.app_name = None
self.classifier_mode = None
self.entity_classifier_obj = EntityClassifier()

@staticmethod
Expand Down Expand Up @@ -177,25 +181,33 @@ def _get_doc_classification(self, doc):
)
try:
if doc_info.data:
topics, topic_count, topic_details = topic_classifier_obj.predict(
doc_info.data
)
(
entities,
entity_count,
anonymized_doc,
entity_details,
) = self.entity_classifier_obj.presidio_entity_classifier_and_anonymizer(
doc_info.data,
anonymize_snippets=ClassifierConstants.anonymize_snippets.value,
)
doc_info.topics = topics
doc_info.entities = entities
doc_info.entityDetails = entity_details
doc_info.topicCount = topic_count
doc_info.entityCount = entity_count
doc_info.topicDetails = topic_details
doc_info.data = anonymized_doc
if self.classifier_mode and self.classifier_mode in [
ClassificationMode.ALL.value,
ClassificationMode.TOPIC.value,
]:
topics, topic_count, topic_details = topic_classifier_obj.predict(
doc_info.data
)
doc_info.topics = topics
doc_info.topicCount = topic_count
doc_info.topicDetails = topic_details
if self.classifier_mode and self.classifier_mode in [
ClassificationMode.ALL.value,
ClassificationMode.ENTITY.value,
]:
(
entities,
entity_count,
anonymized_doc,
entity_details,
) = self.entity_classifier_obj.presidio_entity_classifier_and_anonymizer(
doc_info.data,
anonymize_snippets=ClassifierConstants.anonymize_snippets.value,
)
doc_info.entities = entities
doc_info.entityDetails = entity_details
doc_info.entityCount = entity_count
doc_info.data = anonymized_doc
logger.debug("Doc classification finished.")
return doc_info
except Exception as e:
Expand Down Expand Up @@ -270,6 +282,13 @@ def process_request(self, data):
self.data = data
self.app_name = data.get("name")

if not self.data.get("classifier_mode"):
self.classifier_mode = config_details.get("classifier", {}).get(
"mode", ClassificationMode.ALL.value
)
else:
self.classifier_mode = self.data.get("classifier_mode")

# create session
self.db.create_session()

Expand Down
34 changes: 27 additions & 7 deletions tests/app/config/test_config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,33 +122,52 @@ def test_reports_config_validate(setup_and_teardown):

def test_classifier_config_validate():
# Test with True value
config = {"anonymizeSnippets": True}
config = {"mode": "all", "anonymizeSnippets": True}
validator = ClassifierConfig(config)
validator.validate()
assert validator.errors == []

# Test with False value
config = {"anonymizeSnippets": False}
# Test with anonymizeSnippets False value
config = {"mode": "all", "anonymizeSnippets": False}
validator = ClassifierConfig(config)
validator.validate()
assert validator.errors == []

# Test with invalid int
config = {"anonymizeSnippets": 70000}
# Test with mode entity value
config = {"mode": "entity", "anonymizeSnippets": False}
validator = ClassifierConfig(config)
validator.validate()
assert validator.errors == []

# Test with mode topic value
config = {"mode": "topic", "anonymizeSnippets": False}
validator = ClassifierConfig(config)
validator.validate()
assert validator.errors == []

# Test with invalid anonymizeSnippets values
config = {"mode": "all", "anonymizeSnippets": 70000}
validator = ClassifierConfig(config)
validator.validate()
assert validator.errors == [
"Error: Invalid anonymizeSnippets '70000'. anonymizeSnippets must be a boolean."
]

# Test with invalid str
config = {"anonymizeSnippets": "abc"}
config = {"mode": "all", "anonymizeSnippets": "abc"}
validator = ClassifierConfig(config)
validator.validate()
assert validator.errors == [
"Error: Invalid anonymizeSnippets 'abc'. anonymizeSnippets must be a boolean."
]

# Test with invalid mode values
config = {"mode": "Wrong", "anonymizeSnippets": True}
validator = ClassifierConfig(config)
validator.validate()
assert validator.errors == [
"Error: Unsupported classifier mode 'Wrong' specified in the configuration. Valid values are ['all', 'entity', 'topic']"
]


def test_validate_config(setup_and_teardown):
# Test with valid configuration
Expand All @@ -161,6 +180,7 @@ def test_validate_config(setup_and_teardown):
"cacheDir": "~/.pebblo_test_",
},
"classifier": {
"mode": "all",
"anonymizeSnippets": True,
},
"storage": {"type": "file"},
Expand Down
2 changes: 1 addition & 1 deletion tests/app/service/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_process_request_success(mock_entity_classifier, mock_topic_classifier):
cls_obj = Classification(data)
response = cls_obj.process_request()
expected_response = AiDataModel(
data="",
data=data["data"],
entities={"us-ssn": 1},
entityCount=1,
entityDetails={
Expand Down
2 changes: 1 addition & 1 deletion tests/log/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
daemon=PortConfig(host="localhost", port=8000),
logging=LoggingConfig(),
reports=ReportConfig(format="pdf", renderer="xhtml2pdf", cacheDir="~/.pebblo"),
classifier=ClassifierConfig(anonymizeSnippets=False),
classifier=ClassifierConfig(mode="all", anonymizeSnippets=False),
storage=StorageConfig(type="db"),
)
var_server_config.set(config)

0 comments on commit 6ebdf6e

Please sign in to comment.