Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/keycloak timeout #445

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion src/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from contextlib import asynccontextmanager
import os

import requests # type: ignore

import uvicorn # type: ignore
Comment on lines +5 to 6
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import uvicorn # type: ignore
import uvicorn # type: ignore

from pathlib import Path
from fastapi import (
Expand Down Expand Up @@ -44,6 +46,8 @@
BackendStatusDatasetsSchema,
AgentSchema,
ServerSchema,
LoginSchema,
TokenSchema,
)


Expand All @@ -70,6 +74,24 @@ def with_plugins() -> Iterable[PluginManager]:
plugins.cleanup()


def get_new_tokens(refresh_token):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add typing (I guess get_new_tokens(refresh_token: str) -> Tuple[str, str], but recheck)

data = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
"client_id": db.config.openid_client_id,
"client_secret": db.config.openid_secret,
}
url = "http://mquery-keycloak-1:8080/auth/realms/myrealm/protocol/openid-connect/token"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't be hardcoded - use url from config instead

try:
response: requests.Response = requests.post(url=url, data=data)
token_data = response.json()
new_refresh_token = token_data["refresh_token"]
new_token = token_data["access_token"]
return new_token, new_refresh_token
except requests.exceptions.RequestException:
return None, None


class User:
def __init__(self, token: Optional[Dict]) -> None:
self.__token = token
Expand Down Expand Up @@ -124,8 +146,11 @@ async def current_user(authorization: Optional[str] = Header(None)) -> User:
token_json = jwt.decode(
token, public_key, algorithms=["RS256"], audience="account" # type: ignore
)
except jwt.ExpiredSignatureError:
# token expired so user is anonymous
return User(None)
except jwt.InvalidTokenError:
# Invalid token means invalid signature, issuer, or just expired.
# Invalid token means invalid signature, issuer.
raise unauthorized

return User(token_json)
Expand Down Expand Up @@ -584,6 +609,35 @@ def server() -> ServerSchema:
)


@app.post("/api/login", response_model=LoginSchema, tags=["stable"])
async def login(request: Request, response: Response) -> LoginSchema:
token = await request.json()
if token["refresh_token"]:
response.set_cookie(
key="refresh_token",
value=token["refresh_token"],
httponly=True,
max_age=1800,
)
return LoginSchema(status="OK")
return LoginSchema(status="Bad Token")


@app.post("/api/token/refresh", response_model=TokenSchema)
def refresh_token(request: Request, response: Response) -> TokenSchema:
refresh_token_value = request.cookies.get("refresh_token")
if refresh_token_value:
new_token, new_refresh_token = get_new_tokens(refresh_token_value)
response.set_cookie(
key="refresh_token",
value=new_refresh_token,
httponly=True,
max_age=1800,
)
return TokenSchema(token=new_token)
return TokenSchema(token=None)


@app.get("/query/{path}", include_in_schema=False)
def serve_index(path: str) -> FileResponse:
return FileResponse(Path(__file__).parent / "mqueryfront/dist/index.html")
Expand Down
25 changes: 19 additions & 6 deletions src/mqueryfront/src/App.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React, { useState, useEffect } from "react";
import React, { useState, useRef, useEffect } from "react";
import { Routes, Route } from "react-router-dom";
import Navigation from "./Navigation";
import QueryPage from "./query/QueryPage";
Expand All @@ -9,6 +9,7 @@ import AboutPage from "./about/AboutPage";
import AuthPage from "./auth/AuthPage";
import api, { parseJWT } from "./api";
import "./App.css";
import { refreshAccesToken, storeTokenData, clearTokenData } from "./utils";

function getCurrentTokenOrNull() {
// This function handles missing and corrupted token in the same way.
Expand All @@ -21,20 +22,32 @@ function getCurrentTokenOrNull() {

function App() {
const [config, setConfig] = useState(null);
const tokenIntervalRef = useRef(null);

useEffect(() => {
api.get("/server").then((response) => {
setConfig(response.data);
});
tokenIntervalRef.current = setInterval(() => {
refreshAccesToken();
}, 900000); // refresh token every 15 minutes just in case user was idle.
return () => clearInterval(tokenIntervalRef.current);
}, []);

const login = (rawToken) => {
localStorage.setItem("rawToken", rawToken);
window.location.href = "/";
const login = async (token_data) => {
token_data.not_before_policy = token_data["not-before-policy"];
delete token_data["not-before-policy"];
const response = await api.post("/login", token_data);
storeTokenData(token_data["access_token"]);
const location_href = localStorage.getItem("currentLocation");
if (location_href) {
window.location.href = location_href;
} else {
window.location.href = "/";
}
};

const logout = () => {
localStorage.removeItem("rawToken");
clearTokenData(tokenIntervalRef.current);
if (config !== null) {
const logout_url = new URL(config["openid_url"] + "/logout");
logout_url.searchParams.append(
Expand Down
8 changes: 7 additions & 1 deletion src/mqueryfront/src/api.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import axios from "axios";
import { refreshAccesToken, tokenExpired } from "./utils";

export const api_url = "/api";

Expand All @@ -8,7 +9,11 @@ export function parseJWT(token) {
return JSON.parse(atob(base64));
}

function request(method, path, payload, params) {
async function request(method, path, payload, params) {
if (tokenExpired()) {
// If the token expired, try to refresh it with refresh_token
await refreshAccesToken();
}
const rawToken = localStorage.getItem("rawToken");
const headers = rawToken ? { Authorization: `Bearer ${rawToken}` } : {};
return axios
Expand All @@ -17,6 +22,7 @@ function request(method, path, payload, params) {
data: payload,
params: params,
headers: headers,
withCredentials: true,
})
.catch((error) => {
if (error.response.status === 401) {
Expand Down
2 changes: 1 addition & 1 deletion src/mqueryfront/src/auth/AuthPage.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class AuthPage extends Component {
axios
.post(this.props.config["openid_url"] + "/token", params)
.then((response) => {
this.props.login(response.data["access_token"]);
this.props.login(response.data);
})
.catch((error) => {
this.setState({ error: error });
Expand Down
1 change: 1 addition & 0 deletions src/mqueryfront/src/config/ConfigPage.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class ConfigPage extends Component {
}

componentDidMount() {
localStorage.setItem("currentLocation", window.location.href);
api.get("/config")
.then((response) => {
this.setState({ config: response.data });
Expand Down
1 change: 1 addition & 0 deletions src/mqueryfront/src/query/QueryPage.js
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class QueryPageInner extends Component {
}

async componentDidMount() {
localStorage.setItem("currentLocation", window.location.href);
if (this.queryHash) {
this.fetchJob();
}
Expand Down
1 change: 1 addition & 0 deletions src/mqueryfront/src/recent/RecentPage.js
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class RecentPage extends Component {
}

componentDidMount() {
localStorage.setItem("currentLocation", window.location.href);
api.get("/job")
.then((response) => {
const { jobs } = response.data;
Expand Down
1 change: 1 addition & 0 deletions src/mqueryfront/src/status/StatusPage.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class StatusPage extends Component {
}

componentDidMount() {
localStorage.setItem("currentLocation", window.location.href);
api.get("/backend")
.then((response) => {
this.setState({ backend: response.data });
Expand Down
44 changes: 44 additions & 0 deletions src/mqueryfront/src/utils.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import axios from "axios";
import api, { parseJWT } from "./api";
export const isStatusFinished = (status) =>
["done", "cancelled"].includes(status);

Expand Down Expand Up @@ -25,3 +27,45 @@ export const openidLoginUrl = (config) => {
);
return login_url;
};

export const storeTokenData = (token) => {
localStorage.setItem("rawToken", token);
const decodedToken = parseJWT(token);
localStorage.setItem("expiresAt", decodedToken.exp * 1000);
};

export const refreshAccesToken = async () => {
const rawToken = localStorage.getItem("rawToken");
const expiresAt = localStorage.getItem("expiresAt");
if (rawToken) {
const headers = rawToken ? { Authorization: `Bearer ${rawToken}` } : {};
const response = await axios.request("/api/token/refresh", {
method: "POST",
headers: headers,
withCredentials: true,
});
if (response.data["token"]) {
storeTokenData(response.data["token"]);
} else {
return;
}
}
};

export const clearTokenData = (tokenInterval) => {
clearInterval(tokenInterval);
localStorage.removeItem("expiresAt");
localStorage.removeItem("rawToken");
};

export const tokenExpired = () => {
const rawToken = localStorage.getItem("rawToken");
if (rawToken) {
const expiresAt = localStorage.getItem("expiresAt");
if (Date.now() > expiresAt) {
return true;
}
return false;
}
return false;
};
8 changes: 8 additions & 0 deletions src/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,11 @@ class ServerSchema(BaseModel):
openid_url: Optional[str]
openid_client_id: Optional[str]
about: str


class LoginSchema(BaseModel):
status: str


class TokenSchema(BaseModel):
token: str | None
Loading