Skip to content

Commit

Permalink
support request type and CORS configuration in rest endpoint (#5518)
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 86d6038f0fd4cda34adb03b9d61362d54ec8cb7c
  • Loading branch information
zxqfd555-pw authored and Manul from Pathway committed Jan 29, 2024
1 parent 32042a5 commit 56e2eb0
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 48 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]

### Added
- `pw.io.http.rest_connector` can now handle different kinds of HTTP requests.
- `pw.io.http.PathwayWebserver` can now enable CORS on the added endpoints.

### Fixed
- Returning `pw.Duration` from UDFs or using them as constant values no longer results in errors.

Expand Down
34 changes: 24 additions & 10 deletions integration_tests/webserver/test_rest_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class InputSchema(pw.Schema):
pw.run()


def test_server_two_endpoints(tmp_path: pathlib.Path, port: int):
def _test_server_two_endpoints(tmp_path: pathlib.Path, port: int, with_cors: bool):
output_path = tmp_path / "output.csv"

class InputSchema(pw.Schema):
Expand All @@ -243,10 +243,7 @@ def target():
)
r.raise_for_status()
assert r.text == '"oneone"', r.text
r = requests.post(
f"http://127.0.0.1:{port}/duplicate",
json={"query": "two", "user": "sergey"},
)
r = requests.get(f"http://127.0.0.1:{port}/duplicate?query=two&user=sergey")
r.raise_for_status()
assert r.text == '"twotwo"', r.text
r = requests.post(
Expand All @@ -255,19 +252,24 @@ def target():
)
r.raise_for_status()
assert r.text == '"ONE"', r.text
r = requests.post(
f"http://127.0.0.1:{port}/uppercase",
json={"query": "two", "user": "sergey"},
)
r = requests.get(f"http://127.0.0.1:{port}/uppercase?query=two&user=sergey")
r.raise_for_status()
assert r.text == '"TWO"', r.text

webserver = pw.io.http.PathwayWebserver(host="127.0.0.1", port=port)
webserver = pw.io.http.PathwayWebserver(
host="127.0.0.1",
port=port,
with_cors=with_cors,
)

uppercase_queries, uppercase_response_writer = pw.io.http.rest_connector(
webserver=webserver,
schema=InputSchema,
route="/uppercase",
methods=(
"GET",
"POST",
),
delete_completed_queries=True,
)
uppercase_responses = uppercase_logic(uppercase_queries)
Expand All @@ -277,6 +279,10 @@ def target():
webserver=webserver,
schema=InputSchema,
route="/duplicate",
methods=(
"GET",
"POST",
),
delete_completed_queries=True,
)
duplicate_responses = duplicate_logic(duplicate_queries)
Expand All @@ -292,6 +298,14 @@ def target():
wait_result_with_checker(CsvLinesNumberChecker(output_path, 8), 30)


def test_server_two_endpoints_without_cors(tmp_path: pathlib.Path, port: int):
_test_server_two_endpoints(tmp_path, port, with_cors=False)


def test_server_two_endpoints_with_cors(tmp_path: pathlib.Path, port: int):
_test_server_two_endpoints(tmp_path, port, with_cors=True)


def test_server_schema_generation_via_endpoint(port: int):
class InputSchema(pw.Schema):
query: str
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [
"jupyter_bokeh >= 3.0.7",
"jmespath >= 1.0.1",
"Office365-REST-Python-Client >= 2.5.3",
"aiohttp_cors >= 0.7.0",
]

[project.optional-dependencies]
Expand Down
138 changes: 100 additions & 38 deletions python/pathway/io/http/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import logging
import threading
from collections.abc import Callable
from typing import Any
from typing import Any, Sequence
from uuid import uuid4
from warnings import warn

import aiohttp_cors
import yaml
from aiohttp import web

Expand Down Expand Up @@ -56,6 +57,8 @@ class PathwayWebserver:
with_schema_endpoint: If set to True, the server will also provide ``/_schema`` \
endpoint containing Open API 3.0.3 schema for the handlers generated with \
``pw.io.http.rest_connector`` calls.
with_cors: If set to True, the server will allow cross-origin requests on the \
added endpoints.
"""

_host: str
Expand All @@ -65,13 +68,18 @@ class PathwayWebserver:
_app: web.Application
_is_launched: bool

def __init__(self, host, port, with_schema_endpoint=True):
def __init__(self, host, port, with_schema_endpoint=True, with_cors=False):
self._host = host
self._port = port

self._tasks = {}
self._loop = asyncio.new_event_loop()
self._app = web.Application()
self._registered_routes = {}
if with_cors:
self._cors = aiohttp_cors.setup(self._app)
else:
self._cors = None
self._is_launched = False
self._app_start_mutex = threading.Lock()
self._openapi_description = {
Expand All @@ -84,7 +92,27 @@ def __init__(self, host, port, with_schema_endpoint=True):
"servers": [{"url": f"http://{host}:{port}/"}],
}
if with_schema_endpoint:
self._app.add_routes([web.get("/_schema", self._schema_handler)])
self._add_endpoint_to_app("GET", "/_schema", self._schema_handler)

def _add_endpoint_to_app(self, method, route, handler):
if route not in self._registered_routes:
app_resource = self._app.router.add_resource(route)
if self._cors is not None:
app_resource = self._cors.add(app_resource)
self._registered_routes[route] = app_resource

app_resource_endpoint = self._registered_routes[route].add_route(
method, handler
)
if self._cors is not None:
self._cors.add(
app_resource_endpoint,
{
"*": aiohttp_cors.ResourceOptions(
expose_headers="*", allow_headers="*"
)
},
)

async def _schema_handler(self, request: web.Request):
format = request.query.get("format", "yaml")
Expand Down Expand Up @@ -123,6 +151,23 @@ def _construct_openapi_plaintext_schema(self, schema) -> dict:

return description

def _construct_openapi_get_request_schema(self, schema) -> list:
parameters = []
for name, props in schema.columns().items():
field_description = {
"in": "query",
"name": name,
"required": not props.has_default_value(),
}
openapi_type = _ENGINE_TO_OPENAPI_TYPE.get(props.dtype.map_to_engine())
if openapi_type:
field_description["schema"] = {
"type": openapi_type,
}
parameters.append(field_description)

return parameters

def _construct_openapi_json_schema(self, schema) -> dict:
properties = {}
required = []
Expand Down Expand Up @@ -161,38 +206,51 @@ def _construct_openapi_json_schema(self, schema) -> dict:

return result

def _register_endpoint(self, route, handler, format, schema) -> None:
self._app.add_routes([web.post(route, handler)])
def _register_endpoint(self, route, handler, format, schema, methods) -> None:
simple_responses_dict = {
"200": {
"description": "OK",
},
"400": {
"description": "The request is incorrect. Please check if "
"it complies with the auto-generated and Pathway input "
"table schemas"
},
}

content = {}
if format == "raw":
content["text/plain"] = {
"schema": self._construct_openapi_plaintext_schema(schema)
}
elif format == "custom":
content["application/json"] = {
"schema": self._construct_openapi_json_schema(schema)
}
else:
raise ValueError(f"Unknown endpoint input format: {format}")
self._openapi_description["paths"][route] = {} # type: ignore[index]
for method in methods:
self._add_endpoint_to_app(method, route, handler)

if method == "GET":
content = {
"parameters": self._construct_openapi_get_request_schema(schema),
"responses": simple_responses_dict,
}
elif format == "raw":
content = {
"text/plain": {
"schema": self._construct_openapi_plaintext_schema(schema)
}
}
elif format == "custom":
content = {
"application/json": {
"schema": self._construct_openapi_json_schema(schema)
}
}
else:
raise ValueError(f"Unknown endpoint input format: {format}")

self._openapi_description["paths"][route] = { # type: ignore[index]
"post": {
"requestBody": {
"content": content,
},
"responses": {
"200": {
"description": "OK",
},
"400": {
"description": "The request is incorrect. Please check if "
"it complies with the auto-generated and Pathway input "
"table schemas"
if method != "GET":
self._openapi_description["paths"][route][method.lower()] = { # type: ignore[index]
"requestBody": {
"content": content,
},
},
}
}
"responses": simple_responses_dict,
}
else:
self._openapi_description["paths"][route]["get"] = content # type: ignore[index]

def _run(self) -> None:
self._app_start_mutex.acquire()
Expand Down Expand Up @@ -234,6 +292,7 @@ def __init__(
self,
webserver: PathwayWebserver,
route: str,
methods: Sequence[str],
schema: type[pw.Schema],
delete_completed_queries: bool,
format: str = "raw",
Expand All @@ -245,7 +304,7 @@ def __init__(
self._delete_completed_queries = delete_completed_queries
self._format = format

webserver._register_endpoint(route, self.handle, format, schema)
webserver._register_endpoint(route, self.handle, format, schema, methods)

def run(self):
self._webserver._run()
Expand All @@ -258,12 +317,12 @@ async def handle(self, request: web.Request):
elif self._format == "custom":
try:
payload = await request.json()
query_params = request.query
for param, value in query_params.items():
if param not in payload:
payload[param] = value
except json.decoder.JSONDecodeError:
raise web.HTTPBadRequest(reason="payload is not a valid json")
payload = {}
query_params = request.query
for param, value in query_params.items():
if param not in payload:
payload[param] = value

self._verify_payload(payload)

Expand Down Expand Up @@ -308,6 +367,7 @@ def rest_connector(
webserver: PathwayWebserver | None = None,
route: str = "/",
schema: type[pw.Schema] | None = None,
methods: Sequence[str] = ("POST",),
autocommit_duration_ms=1500,
keep_queries: bool | None = None,
delete_completed_queries: bool | None = None,
Expand All @@ -325,6 +385,7 @@ def rest_connector(
need to create only one instance of this class per single host-port pair;
route: route which will be listened to by the web server;
schema: schema of the resulting table;
methods: HTTP methods that this endpoint will accept;
autocommit_duration_ms: the maximum time between two commits. Every
autocommit_duration_ms milliseconds, the updates received by the connector are
committed and pushed into Pathway's computation graph;
Expand Down Expand Up @@ -431,6 +492,7 @@ def rest_connector(
subject=RestServerSubject(
webserver=webserver,
route=route,
methods=methods,
schema=schema,
delete_completed_queries=delete_completed_queries,
format=format,
Expand Down
38 changes: 38 additions & 0 deletions python/pathway/tests/test_openapi_schema_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,41 @@ def test_no_routes():
webserver = pw.io.http.PathwayWebserver(host="127.0.0.1", port=8080)
description = webserver.openapi_description_json
openapi_spec_validator.validate(description)


def test_several_methods():
class InputSchema(pw.Schema):
k: int
v: str = pw.column_definition(default_value="hello")

webserver = pw.io.http.PathwayWebserver(host="127.0.0.1", port=8080)
pw.io.http.rest_connector(
webserver=webserver,
methods=("GET", "POST"),
schema=InputSchema,
delete_completed_queries=False,
)

description = webserver.openapi_description_json
openapi_spec_validator.validate(description)

assert set(description["paths"]["/"].keys()) == set(["get", "post"])


def test_all_methods():
class InputSchema(pw.Schema):
k: int
v: str = pw.column_definition(default_value="hello")

webserver = pw.io.http.PathwayWebserver(host="127.0.0.1", port=8080)
pw.io.http.rest_connector(
webserver=webserver,
methods=("GET", "POST", "PUT", "PATCH"),
schema=InputSchema,
delete_completed_queries=False,
)

description = webserver.openapi_description_json
openapi_spec_validator.validate(description)

assert set(description["paths"]["/"].keys()) == set(["get", "post", "put", "patch"])
1 change: 1 addition & 0 deletions python/pathway/xpacks/llm/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def serve(route, schema, handler):
queries, writer = pw.io.http.rest_connector(
webserver=webserver,
route=route,
methods=("GET", "POST"),
schema=schema,
autocommit_duration_ms=50,
delete_completed_queries=True,
Expand Down

0 comments on commit 56e2eb0

Please sign in to comment.