Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add no_validate Flag To @router.command #6988

Open
wants to merge 19 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions openbb_platform/core/openbb_core/api/app_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import List, Optional

from fastapi import APIRouter, FastAPI
from fastapi.exceptions import ResponseValidationError
from openbb_core.api.exception_handlers import ExceptionHandlers
from openbb_core.app.model.abstract.error import OpenBBError
from openbb_core.app.router import RouterLoader
Expand Down Expand Up @@ -38,6 +39,7 @@ def add_exception_handlers(app: FastAPI):
"""Add exception handlers."""
app.exception_handlers[Exception] = ExceptionHandlers.exception
app.exception_handlers[ValidationError] = ExceptionHandlers.validation
app.exception_handlers[ResponseValidationError] = ExceptionHandlers.validation
app.exception_handlers[OpenBBError] = ExceptionHandlers.openbb
app.exception_handlers[EmptyDataError] = ExceptionHandlers.empty_data
app.exception_handlers[UnauthorizedError] = ExceptionHandlers.unauthorized
34 changes: 29 additions & 5 deletions openbb_platform/core/openbb_core/api/exception_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import logging
from collections.abc import Iterable
from typing import Any
from typing import Any, Union

from fastapi import Request
from fastapi.exceptions import ResponseValidationError
from fastapi.responses import JSONResponse, Response
from openbb_core.app.model.abstract.error import OpenBBError
from openbb_core.env import Env
Expand Down Expand Up @@ -35,7 +36,8 @@ async def _handle(exception: Exception, status_code: int, detail: Any):
@staticmethod
async def exception(_: Request, error: Exception) -> JSONResponse:
"""Exception handler for Base Exception."""
errors = error.errors(include_url=False) if hasattr(error, "errors") else error
errors = error.errors if hasattr(error, "errors") else error

if errors:
if isinstance(errors, ValueError):
return await ExceptionHandlers._handle(
Expand All @@ -55,19 +57,41 @@ async def exception(_: Request, error: Exception) -> JSONResponse:
return await ExceptionHandlers._handle(
exception=error,
status_code=500,
detail=f"Unexpected Error -> {error.__class__.__name__} -> {str(error.args[0] or error.args)}",
detail=f"Unexpected Error -> {error.__class__.__name__} -> {error}",
)

@staticmethod
async def validation(request: Request, error: ValidationError):
async def validation(
request: Request, error: Union[ValidationError, ResponseValidationError]
):
"""Exception handler for ValidationError."""
# Some validation is performed at Fetcher level.
# So we check if the validation error comes from a QueryParams class.
# And that it is in the request query params.
# If yes, we update the error location with query.
# If not, we handle it as a base Exception error.
query_params = dict(request.query_params)
errors = error.errors(include_url=False)
if isinstance(error, ResponseValidationError):
detail = [
{
**{k: v for k, v in err.items() if k != "ctx"},
"loc": ("query",) + err.get("loc", ()),
}
for err in error.errors()
]
return await ExceptionHandlers._handle(
exception=error,
status_code=422,
detail=detail,
)
try:
errors = (
error.errors(include_url=False)
if hasattr(error, "errors")
else error.errors
)
except Exception:
errors = error.errors if hasattr(error, "errors") else error
all_in_query = all(
loc in query_params for err in errors for loc in err.get("loc", ())
)
Expand Down
13 changes: 11 additions & 2 deletions openbb_platform/core/openbb_core/api/router/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ def exclude_fields_from_api(key: str, value: Any):
elif is_model(getattr(field, "annotation", None)):
exclude_fields_from_api(field_name, getattr(value, field_name))

# Let a non-OBBject object pass through without validation
if not isinstance(c_out, OBBject):
piiq marked this conversation as resolved.
Show resolved Hide resolved
return c_out

for k, v in c_out.model_copy():
exclude_fields_from_api(k, v)

Expand All @@ -181,12 +185,15 @@ def build_api_wrapper(
func: Callable = route.endpoint # type: ignore
path: str = route.path # type: ignore

no_validate = route.openapi_extra.get("no_validate")
new_signature = build_new_signature(path=path, func=func)
new_annotations_map = build_new_annotation_map(sig=new_signature)

func.__signature__ = new_signature # type: ignore
func.__annotations__ = new_annotations_map

if no_validate is True:
route.response_model = None

@wraps(wrapped=func)
async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> OBBject:
user_settings: UserSettings = UserSettings.model_validate(
Expand Down Expand Up @@ -238,8 +245,10 @@ async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> OBBject:
kwargs["extra_params"] = extra_params

execute = partial(command_runner.run, path, user_settings)
output: OBBject = await execute(*args, **kwargs)
output = await execute(*args, **kwargs)

if no_validate is True:
return output
return validate_output(output)

return wrapper
Expand Down
88 changes: 60 additions & 28 deletions openbb_platform/core/openbb_core/app/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
from openbb_core.app.model.metadata import Metadata
from openbb_core.app.model.obbject import OBBject
from openbb_core.app.provider_interface import ExtraParams
from openbb_core.app.static.package_builder import PathHandler
from openbb_core.env import Env
from openbb_core.provider.utils.helpers import maybe_coroutine, run_async
from pydantic import BaseModel, ConfigDict, create_model

if TYPE_CHECKING:
from fastapi.routing import APIRoute
from openbb_core.app.model.system_settings import SystemSettings
from openbb_core.app.model.user_settings import UserSettings
from openbb_core.app.router import CommandMap
Expand All @@ -29,6 +31,9 @@
class ExecutionContext:
"""Execution context."""

# For checking if the command specifies no validation in the API Route
_route_map = PathHandler.build_route_map()

def __init__(
self,
command_map: "CommandMap",
Expand All @@ -42,6 +47,11 @@ def __init__(
self.system_settings = system_settings
self.user_settings = user_settings

@property
def api_route(self) -> "APIRoute":
"""API route."""
return self._route_map[self.route]


class ParametersBuilder:
"""Build parameters for a function."""
Expand Down Expand Up @@ -222,7 +232,12 @@ async def _command(
) -> OBBject:
"""Run a command and return the output."""
obbject = await maybe_coroutine(func, **kwargs)
obbject.provider = getattr(kwargs.get("provider_choices"), "provider", None)
if isinstance(obbject, OBBject):
obbject.provider = getattr(
kwargs.get("provider_choices"),
"provider",
getattr(obbject, "provider", None),
)
return obbject

@classmethod
Expand Down Expand Up @@ -314,21 +329,49 @@ async def _execute_func( # pylint: disable=too-many-positional-arguments
for name, default in model_headers.items() or {}
} or None

validate = not execution_context.api_route.openapi_extra.get("no_validate")
try:
obbject = await cls._command(func, kwargs)

# This section prepares the obbject to pass to the charting service.
obbject._route = route # pylint: disable=protected-access
std_params = cls._extract_params(kwargs, "standard_params") or (
kwargs if "data" in kwargs else {}
)
extra_params = cls._extract_params(kwargs, "extra_params")
obbject._standard_params = ( # pylint: disable=protected-access
std_params
)
obbject._extra_params = extra_params # pylint: disable=protected-access
if chart and obbject.results:
cls._chart(obbject, **kwargs)
# The output might be from a router command with 'no_validate=True'
# It might be of a different type than OBBject.
# In this case, we avoid accessing those attributes.
if isinstance(obbject, OBBject) or validate:
if validate and not isinstance(obbject, OBBject):
raise OpenBBError(
TypeError(
f"Expected OBBject instance at function output, got {type(obbject)} instead."
)
)
# This section prepares the obbject to pass to the charting service.
obbject._route = route # pylint: disable=protected-access
std_params = cls._extract_params(kwargs, "standard_params") or (
kwargs if "data" in kwargs else {}
)
extra_params = cls._extract_params(kwargs, "extra_params")
obbject._standard_params = ( # pylint: disable=protected-access
std_params
)
obbject._extra_params = ( # pylint: disable=protected-access
extra_params
)
if chart and obbject.results:
cls._chart(obbject, **kwargs)

if warning_list:
if isinstance(obbject, OBBject):
obbject.warnings = []
for w in warning_list:
if isinstance(obbject, OBBject):
obbject.warnings.append(cast_warning(w))
if user_settings.preferences.show_warnings:
showwarning(
message=w.message,
category=w.category,
filename=w.filename,
lineno=w.lineno,
file=w.file,
line=w.line,
)
finally:
ls = LoggingService(system_settings, user_settings)
ls.log(
Expand All @@ -341,19 +384,6 @@ async def _execute_func( # pylint: disable=too-many-positional-arguments
custom_headers=custom_headers,
)

if warning_list:
obbject.warnings = []
for w in warning_list:
obbject.warnings.append(cast_warning(w))
if user_settings.preferences.show_warnings:
showwarning(
message=w.message,
category=w.category,
filename=w.filename,
lineno=w.lineno,
file=w.file,
line=w.line,
)
return obbject

# pylint: disable=W0718
Expand Down Expand Up @@ -385,7 +415,9 @@ async def run(

duration = perf_counter_ns() - start_ns

if execution_context.user_settings.preferences.metadata:
if execution_context.user_settings.preferences.metadata and isinstance(
obbject, OBBject
):
try:
obbject.extra["metadata"] = Metadata(
arguments=kwargs,
Expand Down
Loading
Loading