feat: add static method for read model info from gguf #87

Merged 6 commits on Nov 16, 2024
android/src/main/java/com/rnllama/LlamaContext.java (4 additions, 0 deletions)

@@ -358,6 +358,10 @@ private static String getCpuFeatures() {
    }
  }

  // Reads metadata from a GGUF file without creating a full llama context.
  // Keys listed in `skip` are omitted from the returned map.
  protected static native WritableMap modelInfo(
    String model,
    String[] skip
  );
  protected static native long initContext(
    String model,
    boolean embedding,
android/src/main/java/com/rnllama/RNLlama.java (29 additions, 0 deletions)

@@ -42,6 +42,35 @@ public void setContextLimit(double limit, Promise promise) {
    promise.resolve(null);
  }

  public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
    // Read GGUF metadata on a background thread so the JS thread is not blocked by file I/O.
    new AsyncTask<Void, Void, WritableMap>() {
      private Exception exception;

      @Override
      protected WritableMap doInBackground(Void... voids) {
        try {
          String[] skipArray = new String[skip.size()];
          for (int i = 0; i < skip.size(); i++) {
            skipArray[i] = skip.getString(i);
          }
          return LlamaContext.modelInfo(model, skipArray);
        } catch (Exception e) {
          exception = e;
        }
        return null;
      }

      @Override
      protected void onPostExecute(WritableMap result) {
        if (exception != null) {
          promise.reject(exception);
          return;
        }
        promise.resolve(result);
      }
    }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
  }

  public void initContext(double id, final ReadableMap params, final Promise promise) {
    final int contextId = (int) id;
    AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
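For orientation, here is a rough sketch of how this bridged method can be reached from JavaScript. The `RNLlama` module name under `NativeModules` and the wrapper function are assumptions for illustration only; the library's public entry point is the `loadLlamaModelInfo` export documented in docs/API/README.md further down:

```typescript
import { NativeModules } from 'react-native'

// Hypothetical direct bridge call; in practice the exported
// loadLlamaModelInfo wrapper is the intended API.
const { RNLlama } = NativeModules

async function readModelInfo(path: string): Promise<Record<string, string | number>> {
  // The skip list drops metadata keys from the result;
  // 'tokenizer.ggml.tokens' is an illustrative key for a large entry.
  return RNLlama.modelInfo(path, ['tokenizer.ggml.tokens'])
}
```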
android/src/main/jni.cpp (68 additions, 1 deletion)

@@ -9,8 +9,9 @@
#include <thread>
#include <unordered_map>
#include "llama.h"
#include "llama-impl.h"
#include "ggml.h"
#include "rn-llama.hpp"

#define UNUSED(x) (void)(x)
#define TAG "RNLLAMA_ANDROID_JNI"
@@ -132,6 +133,72 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject value) {
    env->CallVoidMethod(map, putArrayMethod, jKey, value);
}

JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_modelInfo(
    JNIEnv *env,
    jobject thiz,
    jstring model_path_str,
    jobjectArray skip
) {
    UNUSED(thiz);

    const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);

    // Copy the Java string array into a std::vector for the comparisons below.
    std::vector<std::string> skip_vec;
    int skip_len = env->GetArrayLength(skip);
    for (int i = 0; i < skip_len; i++) {
        jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i);
        const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr);
        skip_vec.push_back(skip_chars);
        env->ReleaseStringUTFChars(skip_str, skip_chars);
    }

    struct lm_gguf_init_params params = {
        /*.no_alloc = */ false,
        /*.ctx      = */ NULL,
    };
    struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params);

    if (!ctx) {
        LOGI("%s: failed to load '%s'\n", __func__, model_path_chars);
        // Release the path chars on the error path as well, so the JNI string is not leaked.
        env->ReleaseStringUTFChars(model_path_str, model_path_chars);
        return nullptr;
    }

    auto info = createWriteableMap(env);
    putInt(env, info, "version", lm_gguf_get_version(ctx));
    putInt(env, info, "alignment", lm_gguf_get_alignment(ctx));
    putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx));
    {
        const int n_kv = lm_gguf_get_n_kv(ctx);

        for (int i = 0; i < n_kv; ++i) {
            const char * key = lm_gguf_get_key(ctx, i);

            // Omit any key that appears in the caller-provided skip list.
            bool skipped = false;
            if (skip_len > 0) {
                for (int j = 0; j < skip_len; j++) {
                    if (skip_vec[j] == key) {
                        skipped = true;
                        break;
                    }
                }
            }

            if (skipped) {
                continue;
            }

            const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i);
            putString(env, info, key, value.c_str());
        }
    }

    env->ReleaseStringUTFChars(model_path_str, model_path_chars);
    lm_gguf_free(ctx);

    return reinterpret_cast<jobject>(info);
}

struct callback_context {
    JNIEnv *env;
    rnllama::llama_rn_context *llama;
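A sketch of the object this function resolves to on the JavaScript side, inferred from the `putInt`/`putString` calls above; the example GGUF key is illustrative, since actual keys depend on the model file:

```typescript
// Inferred shape of the WritableMap built by Java_com_rnllama_LlamaContext_modelInfo.
interface GGUFModelInfo {
  version: number       // from lm_gguf_get_version
  alignment: number     // from lm_gguf_get_alignment
  data_offset: number   // from lm_gguf_get_data_offset
  // plus every non-skipped GGUF key/value pair, stringified by
  // rnllama::lm_gguf_kv_to_str, e.g. 'general.architecture': 'llama'
  [ggufKey: string]: number | string
}
```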
android/src/newarch/java/com/rnllama/RNLlamaModule.java (5 additions, 0 deletions)

@@ -37,6 +37,11 @@ public void setContextLimit(double limit, Promise promise) {
    rnllama.setContextLimit(limit, promise);
  }

  @ReactMethod
  public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
    rnllama.modelInfo(model, skip, promise);
  }

  @ReactMethod
  public void initContext(double id, final ReadableMap params, final Promise promise) {
    rnllama.initContext(id, params, promise);
android/src/oldarch/java/com/rnllama/RNLlamaModule.java (5 additions, 0 deletions)

@@ -38,6 +38,11 @@ public void setContextLimit(double limit, Promise promise) {
    rnllama.setContextLimit(limit, promise);
  }

  @ReactMethod
  public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
    rnllama.modelInfo(model, skip, promise);
  }

  @ReactMethod
  public void initContext(double id, final ReadableMap params, final Promise promise) {
    rnllama.initContext(id, params, promise);
cpp/rn-llama.hpp (56 additions, 0 deletions)

@@ -4,11 +4,67 @@
#include <sstream>
#include <iostream>
#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "llama-impl.h"
#include "sampling.h"

namespace rnllama {

static std::string lm_gguf_data_to_str(enum lm_gguf_type type, const void * data, int i) {
    switch (type) {
        case LM_GGUF_TYPE_UINT8:   return std::to_string(((const uint8_t  *)data)[i]);
        case LM_GGUF_TYPE_INT8:    return std::to_string(((const int8_t   *)data)[i]);
        case LM_GGUF_TYPE_UINT16:  return std::to_string(((const uint16_t *)data)[i]);
        case LM_GGUF_TYPE_INT16:   return std::to_string(((const int16_t  *)data)[i]);
        case LM_GGUF_TYPE_UINT32:  return std::to_string(((const uint32_t *)data)[i]);
        case LM_GGUF_TYPE_INT32:   return std::to_string(((const int32_t  *)data)[i]);
        case LM_GGUF_TYPE_UINT64:  return std::to_string(((const uint64_t *)data)[i]);
        case LM_GGUF_TYPE_INT64:   return std::to_string(((const int64_t  *)data)[i]);
        case LM_GGUF_TYPE_FLOAT32: return std::to_string(((const float    *)data)[i]);
        case LM_GGUF_TYPE_FLOAT64: return std::to_string(((const double   *)data)[i]);
        case LM_GGUF_TYPE_BOOL:    return ((const bool *)data)[i] ? "true" : "false";
        default:                   return "unknown type: " + std::to_string(type);
    }
}

static std::string lm_gguf_kv_to_str(const struct lm_gguf_context * ctx_gguf, int i) {
    const enum lm_gguf_type type = lm_gguf_get_kv_type(ctx_gguf, i);

    switch (type) {
        case LM_GGUF_TYPE_STRING:
            return lm_gguf_get_val_str(ctx_gguf, i);
        case LM_GGUF_TYPE_ARRAY:
            {
                const enum lm_gguf_type arr_type = lm_gguf_get_arr_type(ctx_gguf, i);
                int arr_n = lm_gguf_get_arr_n(ctx_gguf, i);
                const void * data = lm_gguf_get_arr_data(ctx_gguf, i);
                std::stringstream ss;
                ss << "[";
                for (int j = 0; j < arr_n; j++) {
                    if (arr_type == LM_GGUF_TYPE_STRING) {
                        std::string val = lm_gguf_get_arr_str(ctx_gguf, i, j);
                        // escape quotes
                        replace_all(val, "\\", "\\\\");
                        replace_all(val, "\"", "\\\"");
                        ss << '"' << val << '"';
                    } else if (arr_type == LM_GGUF_TYPE_ARRAY) {
                        ss << "???";
                    } else {
                        ss << lm_gguf_data_to_str(arr_type, data, j);
                    }
                    if (j < arr_n - 1) {
                        ss << ", ";
                    }
                }
                ss << "]";
                return ss.str();
            }
        default:
            return lm_gguf_data_to_str(type, lm_gguf_get_val_data(ctx_gguf, i), 0);
    }
}

static void llama_batch_clear(llama_batch *batch) {
    batch->n_tokens = 0;
}
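To illustrate the stringification rules above with made-up keys and values: strings pass through verbatim, scalars go through `std::to_string`, arrays become bracketed lists with string elements quoted and escaped, and nested arrays are not expanded:

```typescript
// Illustrative lm_gguf_kv_to_str outputs per GGUF value type (all values invented).
const examples: Record<string, string> = {
  'general.name': 'tinyllama-1.1b',            // LM_GGUF_TYPE_STRING: returned as-is
  'llama.block_count': '22',                   // scalar, via std::to_string
  'some.float.array': '[0.500000, 1.000000]',  // numeric array; std::to_string formatting
  'some.string.array': '["a", "b\\"c"]',       // element quotes/backslashes escaped
  'some.nested.array': '[???]',                // arrays of arrays collapse to "???"
}
```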
docs/API/README.md (29 additions, 8 deletions)

@@ -20,6 +20,7 @@ llama.rn

- [convertJsonSchemaToGrammar](README.md#convertjsonschematogrammar)
- [initLlama](README.md#initllama)
- [loadLlamaModelInfo](README.md#loadllamamodelinfo)
- [releaseAllLlama](README.md#releaseallllama)
- [setContextLimit](README.md#setcontextlimit)

@@ -43,7 +44,7 @@

#### Defined in

[index.ts:52](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L52)

___

@@ -53,7 +54,7 @@ ___

#### Defined in

[index.ts:44](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L44)

___

@@ -63,7 +64,7 @@ ___

#### Defined in

[index.ts:42](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L42)

___

@@ -80,7 +81,7 @@ ___

#### Defined in

[index.ts:32](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L32)

## Functions

@@ -104,7 +105,7 @@ ___

#### Defined in

[grammar.ts:824](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L824)

___

@@ -125,7 +126,27 @@ ___

#### Defined in

[index.ts:208](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L208)

___

### loadLlamaModelInfo

▸ **loadLlamaModelInfo**(`model`): `Promise`<`Object`\>

#### Parameters

| Name | Type |
| :------ | :------ |
| `model` | `string` |

#### Returns

`Promise`<`Object`\>

#### Defined in

[index.ts:202](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L202)
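
A minimal usage sketch (the path and the metadata key are placeholders; available keys depend on the GGUF file):

```typescript
import { loadLlamaModelInfo } from 'llama.rn'

const info = await loadLlamaModelInfo('/path/to/model.gguf')
console.log(info['general.architecture']) // e.g. 'llama'
```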

___

@@ -139,7 +160,7 @@ ___

#### Defined in

[index.ts:245](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L245)

___

@@ -159,4 +180,4 @@ ___

#### Defined in

[index.ts:188](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L188)