Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add configure command to native messaging interface #103

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions scripts/native_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ async def translate(self, text, src=None, trg=None, *, model=None, pivot=None, h
async def download_model(self, model_id, *, update=lambda data: None):
return await self.request("DownloadModel", {"modelID": str(model_id)}, update=update)

async def configure(self, *, threads:int = None, cache_size:int = None):
options = {}
if threads is not None:
options["threads"] = int(threads)
if cache_size is not None:
options["cacheSize"] = int(cache_size)
return await self.request("Configure", options)


def first(iterable, *default):
"""Returns the first value of anything iterable, or throws StopIteration
Expand Down Expand Up @@ -186,6 +194,8 @@ async def test():
if model["id"] == selected_model["id"]
)

await tl.configure(threads=1, cache_size=0)

# Perform some translations, switching between the models
translations = await asyncio.gather(
tl.translate("Hello world!", "en", "de"),
Expand Down
23 changes: 21 additions & 2 deletions src/cli/NativeMsgIface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,17 @@ void NativeMsgIface::handleRequest(ListRequest request) {
writeResponse(request, modelsJson);
}

void NativeMsgIface::handleRequest(ConfigureRequest request) {
marian::bergamot::AsyncService::Config serviceConfig;
serviceConfig.numWorkers = request.threads;
serviceConfig.cacheSize = request.cacheSize;
service_.reset();
service_ = std::make_shared<marian::bergamot::AsyncService>(serviceConfig);

QJsonObject response{}; // I don't know...
writeResponse(request, response);
}

void NativeMsgIface::handleRequest(DownloadRequest request) {
// Edge case: client issued a DownloadRequest before fetching the list of
// remote models because it knows the model ID from a previous run. We still
Expand Down Expand Up @@ -284,7 +295,7 @@ request_variant NativeMsgIface::parseJsonInput(QByteArray input) {

// Define what are mandatory and what are optional request keys
static const QStringList mandatoryKeys({"command", "id", "data"}); // Expected in every message
static const QSet<QString> commandTypes({"ListModels", "DownloadModel", "Translate"});
static const QSet<QString> commandTypes({"ListModels", "DownloadModel", "Translate", "Configure"});
// Json doesn't have schema validation, so validate here, in place:
QString command;
int id;
Expand Down Expand Up @@ -359,13 +370,21 @@ request_variant NativeMsgIface::parseJsonInput(QByteArray input) {
ret.id = id;
for (auto&& key : mandatoryKeysDownload) {
QJsonValueRef val = data[key];
if (val.isNull()) {
if (val.isNull() || val.isUndefined()) {
Copy link
Owner

@XapaJIaMnu XapaJIaMnu May 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So... isNull() is like key: "", whereas isUndefined() is the key is not found at all?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How was it then working beforehand!?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no idea, nor have I tested missing keys properly? They definitely look pretty distinct and incompatible:

inline bool isNull() const { return type() == Null; }
...
inline bool isUndefined() const { return type() == Undefined; }

(From qjsonvalue.h#107-113)

return MalformedRequest{id, QString("data field key %1 cannot be null!").arg(key)};
} else {
ret.modelID = val.toString();
}
}
return ret;
} else if (command == "Configure") {
ConfigureRequest ret;
ret.id = id;
if (!data["threads"].isUndefined())
ret.threads = data["threads"].toInt();
if (!data["cacheSize"].isUndefined())
ret.cacheSize = data["cacheSize"].toInt();
return ret;
} else {
return MalformedRequest{id, QString("Developer error. We shouldn't ever be here! Command: %1").arg(command)};
}
Expand Down
33 changes: 32 additions & 1 deletion src/cli/NativeMsgIface.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,14 +236,41 @@ struct DownloadRequest : Request {

Q_DECLARE_METATYPE(DownloadRequest);

/**
* Change TranslateLocally resource usage for this session.
*
* Request:
* {
* "id": int,
* "command": "Configure",
* "data": {
* "threads": int
* "cacheSize": int (0 means disabled)
* }
* }
*
* Successful response:
* {
* "id": int,
* "success": true,
* "data": {}
* }
*/
struct ConfigureRequest : Request {
int threads;
int cacheSize;
};

Q_DECLARE_METATYPE(ConfigureRequest);

/**
* Internal structure to handle a request that is missing a required field.
*/
struct MalformedRequest : Request {
QString error;
};

using request_variant = std::variant<TranslationRequest, ListRequest, DownloadRequest, MalformedRequest>;
using request_variant = std::variant<TranslationRequest, ListRequest, DownloadRequest, ConfigureRequest, MalformedRequest>;

/**
* Internal structure to cache a loaded direct model (i.e. no pivoting)
Expand Down Expand Up @@ -410,6 +437,10 @@ private slots:
*/
void handleRequest(DownloadRequest myJsonInput);

/**
*/
void handleRequest(ConfigureRequest myJsonInput);

/**
* @brief handleRequest handles a request type MalformedRequest and writes to stdout
* @param myJsonInput MalformedRequest
Expand Down