diff --git a/CHANGELOG.md b/CHANGELOG.md index 49266dcf..1ceda736 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- `pw.io.http.rest_connector` can now handle different kinds of HTTP requests. +- `pw.io.http.PathwayWebserver` can now enable CORS on the added endpoints. + ### Fixed - Returning `pw.Duration` from UDFs or using them as constant values no longer results in errors. diff --git a/integration_tests/webserver/test_rest_connector.py b/integration_tests/webserver/test_rest_connector.py index 2f71116a..38d38f46 100644 --- a/integration_tests/webserver/test_rest_connector.py +++ b/integration_tests/webserver/test_rest_connector.py @@ -218,7 +218,7 @@ class InputSchema(pw.Schema): pw.run() -def test_server_two_endpoints(tmp_path: pathlib.Path, port: int): +def _test_server_two_endpoints(tmp_path: pathlib.Path, port: int, with_cors: bool): output_path = tmp_path / "output.csv" class InputSchema(pw.Schema): @@ -243,10 +243,7 @@ def target(): ) r.raise_for_status() assert r.text == '"oneone"', r.text - r = requests.post( - f"http://127.0.0.1:{port}/duplicate", - json={"query": "two", "user": "sergey"}, - ) + r = requests.get(f"http://127.0.0.1:{port}/duplicate?query=two&user=sergey") r.raise_for_status() assert r.text == '"twotwo"', r.text r = requests.post( @@ -255,19 +252,24 @@ def target(): ) r.raise_for_status() assert r.text == '"ONE"', r.text - r = requests.post( - f"http://127.0.0.1:{port}/uppercase", - json={"query": "two", "user": "sergey"}, - ) + r = requests.get(f"http://127.0.0.1:{port}/uppercase?query=two&user=sergey") r.raise_for_status() assert r.text == '"TWO"', r.text - webserver = pw.io.http.PathwayWebserver(host="127.0.0.1", port=port) + webserver = pw.io.http.PathwayWebserver( + host="127.0.0.1", + port=port, + with_cors=with_cors, + ) uppercase_queries, uppercase_response_writer = pw.io.http.rest_connector( webserver=webserver, schema=InputSchema, route="/uppercase", + methods=( + "GET", + "POST", + ), delete_completed_queries=True, ) uppercase_responses = uppercase_logic(uppercase_queries) @@ -277,6 +279,10 @@ def target(): webserver=webserver, schema=InputSchema, route="/duplicate", + methods=( + "GET", + "POST", + ), delete_completed_queries=True, ) duplicate_responses = duplicate_logic(duplicate_queries) @@ -292,6 +298,14 @@ def target(): wait_result_with_checker(CsvLinesNumberChecker(output_path, 8), 30) +def test_server_two_endpoints_without_cors(tmp_path: pathlib.Path, port: int): + _test_server_two_endpoints(tmp_path, port, with_cors=False) + + +def test_server_two_endpoints_with_cors(tmp_path: pathlib.Path, port: int): + _test_server_two_endpoints(tmp_path, port, with_cors=True) + + def test_server_schema_generation_via_endpoint(port: int): class InputSchema(pw.Schema): query: str diff --git a/pyproject.toml b/pyproject.toml index d3f142c3..8aa54fac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "jupyter_bokeh >= 3.0.7", "jmespath >= 1.0.1", "Office365-REST-Python-Client >= 2.5.3", + "aiohttp_cors >= 0.7.0", ] [project.optional-dependencies] diff --git a/python/pathway/io/http/_server.py b/python/pathway/io/http/_server.py index 546879d5..2d7666f1 100644 --- a/python/pathway/io/http/_server.py +++ b/python/pathway/io/http/_server.py @@ -5,10 +5,11 @@ import logging import threading from collections.abc import Callable -from typing import Any +from typing import Any, Sequence from uuid import uuid4 from warnings import warn +import aiohttp_cors import yaml from aiohttp import web @@ -56,6 +57,8 @@ class PathwayWebserver: with_schema_endpoint: If set to True, the server will also provide ``/_schema`` \ endpoint containing Open API 3.0.3 schema for the handlers generated with \ ``pw.io.http.rest_connector`` calls. + with_cors: If set to True, the server will allow cross-origin requests on the \ +added endpoints. """ _host: str @@ -65,13 +68,18 @@ class PathwayWebserver: _app: web.Application _is_launched: bool - def __init__(self, host, port, with_schema_endpoint=True): + def __init__(self, host, port, with_schema_endpoint=True, with_cors=False): self._host = host self._port = port self._tasks = {} self._loop = asyncio.new_event_loop() self._app = web.Application() + self._registered_routes = {} + if with_cors: + self._cors = aiohttp_cors.setup(self._app) + else: + self._cors = None self._is_launched = False self._app_start_mutex = threading.Lock() self._openapi_description = { @@ -84,7 +92,27 @@ def __init__(self, host, port, with_schema_endpoint=True): "servers": [{"url": f"http://{host}:{port}/"}], } if with_schema_endpoint: - self._app.add_routes([web.get("/_schema", self._schema_handler)]) + self._add_endpoint_to_app("GET", "/_schema", self._schema_handler) + + def _add_endpoint_to_app(self, method, route, handler): + if route not in self._registered_routes: + app_resource = self._app.router.add_resource(route) + if self._cors is not None: + app_resource = self._cors.add(app_resource) + self._registered_routes[route] = app_resource + + app_resource_endpoint = self._registered_routes[route].add_route( + method, handler + ) + if self._cors is not None: + self._cors.add( + app_resource_endpoint, + { + "*": aiohttp_cors.ResourceOptions( + expose_headers="*", allow_headers="*" + ) + }, + ) async def _schema_handler(self, request: web.Request): format = request.query.get("format", "yaml") @@ -123,6 +151,23 @@ def _construct_openapi_plaintext_schema(self, schema) -> dict: return description + def _construct_openapi_get_request_schema(self, schema) -> list: + parameters = [] + for name, props in schema.columns().items(): + field_description = { + "in": "query", + "name": name, + "required": not props.has_default_value(), + } + openapi_type = _ENGINE_TO_OPENAPI_TYPE.get(props.dtype.map_to_engine()) + if openapi_type: + field_description["schema"] = { + "type": openapi_type, + } + parameters.append(field_description) + + return parameters + def _construct_openapi_json_schema(self, schema) -> dict: properties = {} required = [] @@ -161,38 +206,51 @@ def _construct_openapi_json_schema(self, schema) -> dict: return result - def _register_endpoint(self, route, handler, format, schema) -> None: - self._app.add_routes([web.post(route, handler)]) + def _register_endpoint(self, route, handler, format, schema, methods) -> None: + simple_responses_dict = { + "200": { + "description": "OK", + }, + "400": { + "description": "The request is incorrect. Please check if " + "it complies with the auto-generated and Pathway input " + "table schemas" + }, + } - content = {} - if format == "raw": - content["text/plain"] = { - "schema": self._construct_openapi_plaintext_schema(schema) - } - elif format == "custom": - content["application/json"] = { - "schema": self._construct_openapi_json_schema(schema) - } - else: - raise ValueError(f"Unknown endpoint input format: {format}") + self._openapi_description["paths"][route] = {} # type: ignore[index] + for method in methods: + self._add_endpoint_to_app(method, route, handler) + + if method == "GET": + content = { + "parameters": self._construct_openapi_get_request_schema(schema), + "responses": simple_responses_dict, + } + elif format == "raw": + content = { + "text/plain": { + "schema": self._construct_openapi_plaintext_schema(schema) + } + } + elif format == "custom": + content = { + "application/json": { + "schema": self._construct_openapi_json_schema(schema) + } + } + else: + raise ValueError(f"Unknown endpoint input format: {format}") - self._openapi_description["paths"][route] = { # type: ignore[index] - "post": { - "requestBody": { - "content": content, - }, - "responses": { - "200": { - "description": "OK", - }, - "400": { - "description": "The request is incorrect. Please check if " - "it complies with the auto-generated and Pathway input " - "table schemas" + if method != "GET": + self._openapi_description["paths"][route][method.lower()] = { # type: ignore[index] + "requestBody": { + "content": content, }, - }, - } - } + "responses": simple_responses_dict, + } + else: + self._openapi_description["paths"][route]["get"] = content # type: ignore[index] def _run(self) -> None: self._app_start_mutex.acquire() @@ -234,6 +292,7 @@ def __init__( self, webserver: PathwayWebserver, route: str, + methods: Sequence[str], schema: type[pw.Schema], delete_completed_queries: bool, format: str = "raw", @@ -245,7 +304,7 @@ def __init__( self._delete_completed_queries = delete_completed_queries self._format = format - webserver._register_endpoint(route, self.handle, format, schema) + webserver._register_endpoint(route, self.handle, format, schema, methods) def run(self): self._webserver._run() @@ -258,12 +317,12 @@ async def handle(self, request: web.Request): elif self._format == "custom": try: payload = await request.json() - query_params = request.query - for param, value in query_params.items(): - if param not in payload: - payload[param] = value except json.decoder.JSONDecodeError: - raise web.HTTPBadRequest(reason="payload is not a valid json") + payload = {} + query_params = request.query + for param, value in query_params.items(): + if param not in payload: + payload[param] = value self._verify_payload(payload) @@ -308,6 +367,7 @@ def rest_connector( webserver: PathwayWebserver | None = None, route: str = "/", schema: type[pw.Schema] | None = None, + methods: Sequence[str] = ("POST",), autocommit_duration_ms=1500, keep_queries: bool | None = None, delete_completed_queries: bool | None = None, @@ -325,6 +385,7 @@ def rest_connector( need to create only one instance of this class per single host-port pair; route: route which will be listened to by the web server; schema: schema of the resulting table; + methods: HTTP methods that this endpoint will accept; autocommit_duration_ms: the maximum time between two commits. Every autocommit_duration_ms milliseconds, the updates received by the connector are committed and pushed into Pathway's computation graph; @@ -431,6 +492,7 @@ def rest_connector( subject=RestServerSubject( webserver=webserver, route=route, + methods=methods, schema=schema, delete_completed_queries=delete_completed_queries, format=format, diff --git a/python/pathway/tests/test_openapi_schema_generation.py b/python/pathway/tests/test_openapi_schema_generation.py index 5fcb3bca..1b67c299 100644 --- a/python/pathway/tests/test_openapi_schema_generation.py +++ b/python/pathway/tests/test_openapi_schema_generation.py @@ -132,3 +132,41 @@ def test_no_routes(): webserver = pw.io.http.PathwayWebserver(host="127.0.0.1", port=8080) description = webserver.openapi_description_json openapi_spec_validator.validate(description) + + +def test_several_methods(): + class InputSchema(pw.Schema): + k: int + v: str = pw.column_definition(default_value="hello") + + webserver = pw.io.http.PathwayWebserver(host="127.0.0.1", port=8080) + pw.io.http.rest_connector( + webserver=webserver, + methods=("GET", "POST"), + schema=InputSchema, + delete_completed_queries=False, + ) + + description = webserver.openapi_description_json + openapi_spec_validator.validate(description) + + assert set(description["paths"]["/"].keys()) == set(["get", "post"]) + + +def test_all_methods(): + class InputSchema(pw.Schema): + k: int + v: str = pw.column_definition(default_value="hello") + + webserver = pw.io.http.PathwayWebserver(host="127.0.0.1", port=8080) + pw.io.http.rest_connector( + webserver=webserver, + methods=("GET", "POST", "PUT", "PATCH"), + schema=InputSchema, + delete_completed_queries=False, + ) + + description = webserver.openapi_description_json + openapi_spec_validator.validate(description) + + assert set(description["paths"]["/"].keys()) == set(["get", "post", "put", "patch"]) diff --git a/python/pathway/xpacks/llm/vector_store.py b/python/pathway/xpacks/llm/vector_store.py index 2b4d71f9..e75ac0d5 100644 --- a/python/pathway/xpacks/llm/vector_store.py +++ b/python/pathway/xpacks/llm/vector_store.py @@ -359,6 +359,7 @@ def serve(route, schema, handler): queries, writer = pw.io.http.rest_connector( webserver=webserver, route=route, + methods=("GET", "POST"), schema=schema, autocommit_duration_ms=50, delete_completed_queries=True,