NOTE(review): this patch was rebuilt from a copy whose line breaks, indentation
and <...> template arguments had been lost. Context-line whitespace and the
request_variant alternative list are best guesses; confirm them against the
tree (or apply with whitespace ignored). The index hashes predate the edits
made to the added lines below.

diff --git a/scripts/native_client.py b/scripts/native_client.py
index b869b757..2f26f713 100755
--- a/scripts/native_client.py
+++ b/scripts/native_client.py
@@ -117,6 +117,19 @@ async def translate(self, text, src=None, trg=None, *, model=None, pivot=None, h
     async def download_model(self, model_id, *, update=lambda data: None):
         return await self.request("DownloadModel", {"modelID": str(model_id)}, update=update)
 
+    async def configure(self, *, threads: int = None, cache_size: int = None):
+        """Send a "Configure" request and return the result of the request.
+
+        Only arguments that are not None are sent, as the "threads" and
+        "cacheSize" keys of the request data.
+        """
+        options = {}
+        if threads is not None:
+            options["threads"] = int(threads)
+        if cache_size is not None:
+            options["cacheSize"] = int(cache_size)
+        return await self.request("Configure", options)
+
 
 def first(iterable, *default):
     """Returns the first value of anything iterable, or throws StopIteration
@@ -186,6 +199,8 @@ async def test():
             if model["id"] == selected_model["id"]
         )
 
+        await tl.configure(threads=1, cache_size=0)
+
         # Perform some translations, switching between the models
         translations = await asyncio.gather(
             tl.translate("Hello world!", "en", "de"),
diff --git a/src/cli/NativeMsgIface.cpp b/src/cli/NativeMsgIface.cpp
index 8dbd6726..47d4042a 100644
--- a/src/cli/NativeMsgIface.cpp
+++ b/src/cli/NativeMsgIface.cpp
@@ -224,6 +224,24 @@ void NativeMsgIface::handleRequest(ListRequest request) {
     writeResponse(request, modelsJson);
 }
 
+void NativeMsgIface::handleRequest(ConfigureRequest request) {
+    // Start from the service defaults and override only the options the client
+    // actually sent; omitted options are marked by a negative value.
+    marian::bergamot::AsyncService::Config serviceConfig;
+    if (request.threads > 0)
+        serviceConfig.numWorkers = request.threads;
+    if (request.cacheSize >= 0)
+        serviceConfig.cacheSize = request.cacheSize;
+
+    // Drop our reference to the current service before building its replacement.
+    service_.reset();
+    service_ = std::make_shared<marian::bergamot::AsyncService>(serviceConfig);
+
+    // Configure has nothing to report back, so the response data is empty.
+    QJsonObject response{};
+    writeResponse(request, response);
+}
+
 void NativeMsgIface::handleRequest(DownloadRequest request) {
     // Edge case: client issued a DownloadRequest before fetching the list of
     // remote models because it knows the model ID from a previous run. We still
@@ -284,7 +302,7 @@ request_variant NativeMsgIface::parseJsonInput(QByteArray input) {
 
     // Define what are mandatory and what are optional request keys
     static const QStringList mandatoryKeys({"command", "id", "data"}); // Expected in every message
-    static const QSet<QString> commandTypes({"ListModels", "DownloadModel", "Translate"});
+    static const QSet<QString> commandTypes({"ListModels", "DownloadModel", "Translate", "Configure"});
     // Json doesn't have schema validation, so validate here, in place:
     QString command;
     int id;
@@ -359,13 +377,34 @@ request_variant NativeMsgIface::parseJsonInput(QByteArray input) {
         ret.id = id;
         for (auto&& key : mandatoryKeysDownload) {
             QJsonValueRef val = data[key];
-            if (val.isNull()) {
+            if (val.isNull() || val.isUndefined()) {
                 return MalformedRequest{id, QString("data field key %1 cannot be null!").arg(key)};
             } else {
                 ret.modelID = val.toString();
             }
         }
         return ret;
+    } else if (command == "Configure") {
+        ConfigureRequest ret;
+        ret.id = id;
+        // Look the options up with value(): operator[] on this non-const object
+        // inserts a null entry for a missing key, so isUndefined() on its result
+        // can never detect an omitted option.
+        const QJsonValue threads = data.value("threads");
+        if (!threads.isUndefined()) {
+            ret.threads = threads.toInt(-1); // -1 when not a whole number
+            if (ret.threads < 1) {
+                return MalformedRequest{id, QString("data field key %1 must be a positive integer!").arg("threads")};
+            }
+        }
+        const QJsonValue cacheSize = data.value("cacheSize");
+        if (!cacheSize.isUndefined()) {
+            ret.cacheSize = cacheSize.toInt(-1); // -1 when not a whole number
+            if (ret.cacheSize < 0) {
+                return MalformedRequest{id, QString("data field key %1 must be a non-negative integer!").arg("cacheSize")};
+            }
+        }
+        return ret;
     } else {
         return MalformedRequest{id, QString("Developer error. We shouldn't ever be here! Command: %1").arg(command)};
     }
diff --git a/src/cli/NativeMsgIface.h b/src/cli/NativeMsgIface.h
index 6bad9aa0..88071d3f 100644
--- a/src/cli/NativeMsgIface.h
+++ b/src/cli/NativeMsgIface.h
@@ -236,6 +236,36 @@ struct DownloadRequest : Request {
 
 Q_DECLARE_METATYPE(DownloadRequest);
 
+/**
+ * Change TranslateLocally resource usage for this session.
+ *
+ * Both options are optional. An omitted option falls back to the default of
+ * marian::bergamot::AsyncService::Config.
+ *
+ * Request:
+ * {
+ *   "id": int,
+ *   "command": "Configure",
+ *   "data": {
+ *     "threads": int (at least 1),
+ *     "cacheSize": int (0 means disabled)
+ *   }
+ * }
+ *
+ * Successful response:
+ * {
+ *   "id": int,
+ *   "success": true,
+ *   "data": {}
+ * }
+ */
+struct ConfigureRequest : Request {
+    int threads = -1;    // -1: not specified by the client
+    int cacheSize = -1;  // -1: not specified by the client
+};
+
+Q_DECLARE_METATYPE(ConfigureRequest);
+
 /**
  * Internal structure to handle a request that is missing a required field.
  */
@@ -243,7 +273,7 @@ struct MalformedRequest : Request {
     QString error;
 };
 
-using request_variant = std::variant<TranslationRequest, ListRequest, DownloadRequest, MalformedRequest>;
+using request_variant = std::variant<TranslationRequest, ListRequest, DownloadRequest, ConfigureRequest, MalformedRequest>;
 
 /**
  * Internal structure to cache a loaded direct model (i.e. no pivoting)
@@ -410,6 +440,13 @@ private slots:
      */
     void handleRequest(DownloadRequest myJsonInput);
 
+    /**
+     * @brief handleRequest handles a request type ConfigureRequest by replacing the
+     * translation service with one built from the requested settings
+     * @param myJsonInput ConfigureRequest
+     */
+    void handleRequest(ConfigureRequest myJsonInput);
+
     /**
      * @brief handleRequest handles a request type MalformedRequest and writes to stdout
      * @param myJsonInput MalformedRequest