From 3ebfb2a8b4d12fe430d5f3a2226adba8f2b4a725 Mon Sep 17 00:00:00 2001 From: Kevin Ferm Date: Thu, 14 Mar 2024 10:38:31 +0100 Subject: [PATCH] Use a cors middleware for all API routes --- sim/web/main.go | 39 ++++++++++++++++++--------------------- ui/worker/local_worker.js | 4 ++-- vite.config.js | 8 ++++---- 3 files changed, 24 insertions(+), 27 deletions(-) diff --git a/sim/web/main.go b/sim/web/main.go index ee78b1bb11..5649fb2a14 100644 --- a/sim/web/main.go +++ b/sim/web/main.go @@ -219,13 +219,11 @@ func (s *server) handleAsyncAPI(w http.ResponseWriter, r *http.Request) { func (s *server) setupAsyncServer() { // All async handlers here will call the addNewSim, generating a new UUID and cached progress state. for route := range asyncAPIHandlers { - http.HandleFunc(route, func(w http.ResponseWriter, r *http.Request) { - s.handleAsyncAPI(w, r) - }) + http.Handle(route, corsMiddleware(http.HandlerFunc(s.handleAsyncAPI))) } // asyncProgress will fetch the current progress of a simulation by its UUID. - http.HandleFunc("/asyncProgress", func(w http.ResponseWriter, r *http.Request) { + http.Handle("/asyncProgress", corsMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { return @@ -261,9 +259,20 @@ func (s *server) setupAsyncServer() { } w.Header().Add("Content-Type", "application/x-protobuf") w.Write(outbytes) + }))) +} +func corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") + w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization") + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + next.ServeHTTP(w, r) }) } - func (s *server) runServer(useFS bool, host string, launchBrowser bool, simName string, wasm bool, inputReader *bufio.Reader) { s.setupAsyncServer() @@ -277,7 +286,7 @@ func (s *server) runServer(useFS bool, host string, launchBrowser bool, simName } for route := range handlers { - http.HandleFunc(route, handleAPI) + http.Handle(route, corsMiddleware(http.HandlerFunc(handleAPI))) } http.HandleFunc("/version", func(resp http.ResponseWriter, req *http.Request) { @@ -388,22 +397,10 @@ func (s *server) runServer(useFS bool, host string, launchBrowser bool, simName } } -func enableCors(w *http.ResponseWriter) { - (*w).Header().Set("Access-Control-Allow-Origin", "*") // Allow any domain. Adjust as necessary. - (*w).Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") // Adjust the methods based on your requirements. - (*w).Header().Set("Access-Control-Allow-Headers", "*") // Allow any headers. Adjust as necessary. - -} - // handleAPI is generic handler for any api function using protos. func handleAPI(w http.ResponseWriter, r *http.Request) { endpoint := r.URL.Path - enableCors(&w) - // Handle preflight requests - if r.Method == "OPTIONS" { - w.WriteHeader(http.StatusOK) - return - } + body, err := io.ReadAll(r.Body) if err != nil { return @@ -421,13 +418,13 @@ func handleAPI(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) return } - + if googleProto.Equal(msg, msg.ProtoReflect().New().Interface()) { log.Printf("Request is empty") w.WriteHeader(http.StatusBadRequest) return } - + result := handler.handle(msg) outbytes, err := googleProto.Marshal(result) diff --git a/ui/worker/local_worker.js b/ui/worker/local_worker.js index 30306f7031..2aefa7e1d3 100644 --- a/ui/worker/local_worker.js +++ b/ui/worker/local_worker.js @@ -1,6 +1,6 @@ var workerID = ""; -addEventListener('message', async (e) => { +addEventListener('message', async e => { const msg = e.data.msg; const id = e.data.id; @@ -22,7 +22,7 @@ addEventListener('message', async (e) => { var outputData; if (msg == "raidSimAsync" || msg == "statWeightsAsync" || msg == "bulkSimAsync") { while (true) { - let progressResponse = await fetch("/asyncProgress", { + let progressResponse = await fetch("http://localhost:3333/asyncProgress", { method: 'POST', headers: { 'Content-Type': 'application/x-protobuf' diff --git a/vite.config.js b/vite.config.js index c70c5ff707..56d2ecefad 100644 --- a/vite.config.js +++ b/vite.config.js @@ -1,7 +1,7 @@ -import path from "path"; +import fs from 'fs'; import glob from "glob"; +import path from "path"; import { defineConfig } from 'vite' -import fs from 'fs'; function serveExternalAssets() { return { @@ -11,7 +11,7 @@ function serveExternalAssets() { const workerMappings = { '/cata/sim_worker.js': '/cata/local_worker.js', '/cata/net_worker.js': '/cata/net_worker.js', - '/cata/lib.wasm': '/cata/lib.wasm', + '/cata/lib.wasm': '/cata/lib.wasm', }; if (Object.keys(workerMappings).includes(req.url)) { @@ -43,7 +43,7 @@ function serveFile(res, filePath) { res.writeHead(200, { 'Content-Type': contentType }); fs.createReadStream(filePath).pipe(res); } else { - console.log(filePath) + console.log("Not found on filesystem: ", filePath) res.writeHead(404, { 'Content-Type': 'text/plain' }); res.end('Not Found'); }