Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update embedding method #88

Merged
merged 3 commits into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
params.getString("model"),
// boolean embedding,
params.hasKey("embedding") ? params.getBoolean("embedding") : false,
// int embd_normalize,
params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1,
// int n_ctx,
params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
// int n_batch,
Expand All @@ -66,9 +68,14 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
// float rope_freq_scale
params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f,
// int pooling_type,
params.hasKey("pooling_type") ? params.getInt("pooling_type") : -1,
// LoadProgressCallback load_progress_callback
params.hasKey("use_progress_callback") ? new LoadProgressCallback(this) : null
);
if (this.context == -1) {
throw new IllegalStateException("Failed to initialize context");
}
this.modelDetails = loadModelDetails(this.context);
this.reactContext = reactContext;
}
Expand Down Expand Up @@ -258,11 +265,16 @@ public String detokenize(ReadableArray tokens) {
return detokenize(this.context, toks);
}

public WritableMap getEmbedding(String text) {
public WritableMap getEmbedding(String text, ReadableMap params) {
if (isEmbeddingEnabled(this.context) == false) {
throw new IllegalStateException("Embedding is not enabled");
}
WritableMap result = embedding(this.context, text);
WritableMap result = embedding(
this.context,
text,
// int embd_normalize,
params.hasKey("embd_normalize") ? params.getInt("embd_normalize") : -1
);
if (result.hasKey("error")) {
throw new IllegalStateException(result.getString("error"));
}
Expand Down Expand Up @@ -365,6 +377,7 @@ protected static native WritableMap modelInfo(
protected static native long initContext(
String model,
boolean embedding,
int embd_normalize,
int n_ctx,
int n_batch,
int n_threads,
Expand All @@ -376,6 +389,7 @@ protected static native long initContext(
float lora_scaled,
float rope_freq_base,
float rope_freq_scale,
int pooling_type,
LoadProgressCallback load_progress_callback
);
protected static native void interruptLoad(long contextPtr);
Expand Down Expand Up @@ -429,7 +443,11 @@ protected static native WritableMap doCompletion(
protected static native WritableArray tokenize(long contextPtr, String text);
protected static native String detokenize(long contextPtr, int[] tokens);
protected static native boolean isEmbeddingEnabled(long contextPtr);
protected static native WritableMap embedding(long contextPtr, String text);
protected static native WritableMap embedding(
long contextPtr,
String text,
int embd_normalize
);
protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
protected static native void freeContext(long contextPtr);
protected static native void logToAndroid();
Expand Down
4 changes: 2 additions & 2 deletions android/src/main/java/com/rnllama/RNLlama.java
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ protected void onPostExecute(String result) {
tasks.put(task, "detokenize-" + contextId);
}

public void embedding(double id, final String text, final Promise promise) {
public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;
Expand All @@ -361,7 +361,7 @@ protected WritableMap doInBackground(Void... voids) {
if (context == null) {
throw new Exception("Context not found");
}
return context.getEmbedding(text);
return context.getEmbedding(text, params);
} catch (Exception e) {
exception = e;
}
Expand Down
58 changes: 51 additions & 7 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,15 @@ static inline void pushDouble(JNIEnv *env, jobject arr, double value) {
env->CallVoidMethod(arr, pushDoubleMethod, value);
}

// Method to push string into WritableArray.
// Converts `value` (UTF-8) to a jstring and appends it via
// WritableArray.pushString(String).
static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
  jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
  jmethodID pushStringMethod = env->GetMethodID(mapClass, "pushString", "(Ljava/lang/String;)V");

  jstring jValue = env->NewStringUTF(value);
  env->CallVoidMethod(arr, pushStringMethod, jValue);
  // Fix: release the local references created above. This helper is invoked
  // once per token in a loop (e.g. building prompt_tokens), and JNI local
  // references are only reclaimed when the native frame returns — without an
  // explicit DeleteLocalRef the local-reference table can overflow on long
  // prompts.
  env->DeleteLocalRef(jValue);
  env->DeleteLocalRef(mapClass);
}

// Method to push WritableMap into WritableArray
static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
Expand Down Expand Up @@ -213,6 +222,7 @@ Java_com_rnllama_LlamaContext_initContext(
jobject thiz,
jstring model_path_str,
jboolean embedding,
jint embd_normalize,
jint n_ctx,
jint n_batch,
jint n_threads,
Expand All @@ -224,6 +234,7 @@ Java_com_rnllama_LlamaContext_initContext(
jfloat lora_scaled,
jfloat rope_freq_base,
jfloat rope_freq_scale,
jint pooling_type,
jobject load_progress_callback
) {
UNUSED(thiz);
Expand All @@ -238,11 +249,22 @@ Java_com_rnllama_LlamaContext_initContext(
const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
defaultParams.model = model_path_chars;

defaultParams.embedding = embedding;

defaultParams.n_ctx = n_ctx;
defaultParams.n_batch = n_batch;

if (pooling_type != -1) {
defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
}

defaultParams.embedding = embedding;
if (embd_normalize != -1) {
defaultParams.embd_normalize = embd_normalize;
}
if (embedding) {
// For non-causal models, batch size must be equal to ubatch size
defaultParams.n_ubatch = defaultParams.n_batch;
}

int max_threads = std::thread::hardware_concurrency();
// Use 2 threads by default on 4-core devices, 4 threads on more cores
int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
Expand Down Expand Up @@ -291,16 +313,21 @@ Java_com_rnllama_LlamaContext_initContext(

bool is_model_loaded = llama->loadModel(defaultParams);

env->ReleaseStringUTFChars(model_path_str, model_path_chars);
env->ReleaseStringUTFChars(lora_str, lora_chars);

LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
if (is_model_loaded) {
if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
llama_free(llama->ctx);
return -1;
}
context_map[(long) llama->ctx] = llama;
} else {
llama_free(llama->ctx);
}

env->ReleaseStringUTFChars(model_path_str, model_path_chars);
env->ReleaseStringUTFChars(lora_str, lora_chars);

return reinterpret_cast<jlong>(llama->ctx);
}

Expand Down Expand Up @@ -745,10 +772,21 @@ Java_com_rnllama_LlamaContext_isEmbeddingEnabled(

JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_embedding(
JNIEnv *env, jobject thiz, jlong context_ptr, jstring text) {
JNIEnv *env, jobject thiz,
jlong context_ptr,
jstring text,
jint embd_normalize
) {
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];

common_params embdParams;
embdParams.embedding = true;
embdParams.embd_normalize = llama->params.embd_normalize;
if (embd_normalize != -1) {
embdParams.embd_normalize = embd_normalize;
}

const char *text_chars = env->GetStringUTFChars(text, nullptr);

llama->rewind();
Expand All @@ -769,14 +807,20 @@ Java_com_rnllama_LlamaContext_embedding(
llama->loadPrompt();
llama->doCompletion();

std::vector<float> embedding = llama->getEmbedding();
std::vector<float> embedding = llama->getEmbedding(embdParams);

auto embeddings = createWritableArray(env);
for (const auto &val : embedding) {
pushDouble(env, embeddings, (double) val);
}
putArray(env, result, "embedding", embeddings);

auto promptTokens = createWritableArray(env);
for (const auto &tok : llama->embd) {
pushString(env, promptTokens, common_token_to_piece(llama->ctx, tok).c_str());
}
putArray(env, result, "prompt_tokens", promptTokens);

env->ReleaseStringUTFChars(text, text_chars);
return result;
}
Expand Down
4 changes: 2 additions & 2 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ public void detokenize(double id, final ReadableArray tokens, final Promise prom
}

@ReactMethod
public void embedding(double id, final String text, final Promise promise) {
rnllama.embedding(id, text, promise);
public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
rnllama.embedding(id, text, params, promise);
}

@ReactMethod
Expand Down
4 changes: 2 additions & 2 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ public void detokenize(double id, final ReadableArray tokens, final Promise prom
}

@ReactMethod
public void embedding(double id, final String text, final Promise promise) {
rnllama.embedding(id, text, promise);
public void embedding(double id, final String text, final ReadableMap params, final Promise promise) {
rnllama.embedding(id, text, params, promise);
}

@ReactMethod
Expand Down
17 changes: 9 additions & 8 deletions cpp/rn-llama.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,28 +595,29 @@ struct llama_rn_context
return token_with_probs;
}

std::vector<float> getEmbedding()
std::vector<float> getEmbedding(common_params &embd_params)
{
static const int n_embd = llama_n_embd(llama_get_model(ctx));
if (!params.embedding)
if (!embd_params.embedding)
{
LOG_WARNING("embedding disabled, embedding: %s", params.embedding);
LOG_WARNING("embedding disabled, embedding: %s", embd_params.embedding);
return std::vector<float>(n_embd, 0.0f);
}
float *data;

if(params.pooling_type == 0){
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
printf("pooling_type: %d\n", pooling_type);
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
data = llama_get_embeddings(ctx);
}
else {
} else {
data = llama_get_embeddings_seq(ctx, 0);
}

if(!data) {
if (!data) {
return std::vector<float>(n_embd, 0.0f);
}
std::vector<float> embedding(data, data + n_embd), out(data, data + n_embd);
common_embd_normalize(embedding.data(), out.data(), n_embd, params.embd_normalize);
common_embd_normalize(embedding.data(), out.data(), n_embd, embd_params.embd_normalize);
return out;
}

Expand Down
33 changes: 22 additions & 11 deletions docs/API/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ llama.rn
- [BenchResult](README.md#benchresult)
- [CompletionParams](README.md#completionparams)
- [ContextParams](README.md#contextparams)
- [EmbeddingParams](README.md#embeddingparams)
- [TokenData](README.md#tokendata)

### Functions
Expand Down Expand Up @@ -44,7 +45,7 @@ llama.rn

#### Defined in

[index.ts:52](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L52)
[index.ts:57](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L57)

___

Expand All @@ -54,17 +55,27 @@ ___

#### Defined in

[index.ts:44](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L44)
[index.ts:49](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L49)

___

### ContextParams

Ƭ **ContextParams**: `NativeContextParams`
Ƭ **ContextParams**: `Omit`<`NativeContextParams`, ``"pooling_type"``\> & { `pooling_type?`: ``"none"`` \| ``"mean"`` \| ``"cls"`` \| ``"last"`` \| ``"rank"`` }

#### Defined in

[index.ts:42](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L42)
[index.ts:43](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L43)

___

### EmbeddingParams

Ƭ **EmbeddingParams**: `NativeEmbeddingParams`

#### Defined in

[index.ts:47](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L47)

___

Expand All @@ -81,7 +92,7 @@ ___

#### Defined in

[index.ts:32](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L32)
[index.ts:33](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L33)

## Functions

Expand All @@ -105,7 +116,7 @@ ___

#### Defined in

[grammar.ts:824](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/grammar.ts#L824)
[grammar.ts:824](https://github.com/mybigday/llama.rn/blob/20a1819/src/grammar.ts#L824)

___

Expand All @@ -117,7 +128,7 @@ ___

| Name | Type |
| :------ | :------ |
| `«destructured»` | `NativeContextParams` |
| `«destructured»` | [`ContextParams`](README.md#contextparams) |
| `onProgress?` | (`progress`: `number`) => `void` |

#### Returns
Expand All @@ -126,7 +137,7 @@ ___

#### Defined in

[index.ts:208](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L208)
[index.ts:225](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L225)

___

Expand All @@ -146,7 +157,7 @@ ___

#### Defined in

[index.ts:202](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L202)
[index.ts:210](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L210)

___

Expand All @@ -160,7 +171,7 @@ ___

#### Defined in

[index.ts:245](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L245)
[index.ts:269](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L269)

___

Expand All @@ -180,4 +191,4 @@ ___

#### Defined in

[index.ts:188](https://github.com/mybigday/llama.rn/blob/66d2ed3/src/index.ts#L188)
[index.ts:196](https://github.com/mybigday/llama.rn/blob/20a1819/src/index.ts#L196)
Loading
Loading