diff --git a/src/app/__init__.py b/src/app/__init__.py index 9f44074..24809fb 100644 --- a/src/app/__init__.py +++ b/src/app/__init__.py @@ -5,6 +5,7 @@ from flask_limiter.util import get_remote_address from flask_talisman import Talisman from . import csp +from os import environ limiter = Limiter( key_func=get_remote_address, default_limits=["2 per second", "1000 per day"] @@ -24,6 +25,7 @@ def create_app(): app, content_security_policy=csp.csp, content_security_policy_nonce_in=["script-src"], + force_https=not environ.get("FLASK_ENV") == "development", ) from . import config, routes diff --git a/src/app/routes.py b/src/app/routes.py index 9f73284..455c6c1 100644 --- a/src/app/routes.py +++ b/src/app/routes.py @@ -44,8 +44,19 @@ def fluff_results(): sql = "\n".join(sql.splitlines()) + "\n" dialect = request.args["dialect"] - linted = lint(sql, dialect=dialect) - fixed_sql = fix(sql, dialect=dialect) + try: + linted = lint(sql, dialect=dialect) + fixed_sql = fix(sql, dialect=dialect) + except RuntimeError as e: + linted = [ + { + "start_line_no": 1, + "start_line_pos": 1, + "code": "RuntimeError", + "description": str(e), + } + ] + fixed_sql = sql return render_template( "index.html", results=True, diff --git a/test/test_app.py b/test/test_app.py index 4001ec1..2ff7340 100644 --- a/test/test_app.py +++ b/test/test_app.py @@ -4,6 +4,7 @@ import pytest from app.routes import sql_encode from bs4 import BeautifulSoup +from unittest.mock import patch @pytest.fixture @@ -95,6 +96,20 @@ def test_newlines_in_error(client): ) +@patch("app.routes.lint") +def test_runtime_error(mock_lint, client): + """Test that a runtime error is handled.""" + mock_lint.side_effect = RuntimeError("This is a test error") + sql_encoded = sql_encode("select * from table") + rv = client.get("/fluffed", query_string=f"""dialect=ansi&sql={sql_encoded}""") + html = rv.data.decode().lower() + assert "sqlfluff online" in html + assert "fixed sql" in html + assert "select * from table" in html + assert "runtimeerror" in html + assert "this is a test error" in html + + def test_security_headers(client): """Test flask-talisman is setting the security headers""" rv = client.get("/")