Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
sqs committed Dec 27, 2023
1 parent 631cd21 commit 250dda0
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 6 deletions.
1 change: 0 additions & 1 deletion provider/docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ time p run -s docs-query 'making provider work in vscode' $(find ../../web/conte

TODOs:

- use indexeddb for more storage https://stackoverflow.com/questions/5663166/is-there-a-way-to-increase-the-size-of-localstorage-in-google-chrome-to-avoid-qu
- use worker for onnx https://huggingface.co/docs/transformers.js/tutorials/react#step-4-connecting-everything-together
- simplify cache interface
- deal with different content types (markdown/html) differently
Expand Down
4 changes: 2 additions & 2 deletions provider/docs/src/corpus/search/embeddings.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { describe, expect, test } from 'vitest'
import { indexCorpus, type CorpusSearchResult } from '..'
import { corpusData } from '../data'
import { doc } from '../index.test'
import { embeddingsSearch, embedText, similarity } from './embeddings'
import { embeddingsSearch, embedTextInThisScope, similarity } from './embeddings'

describe('embeddingsSearch', () => {
test('finds matches', async () => {
Expand All @@ -14,7 +14,7 @@ describe('embeddingsSearch', () => {

describe('embedText', () => {
test('embeds', async () => {
const s = await embedText('hello world')
const s = await embedTextInThisScope('hello world')
expect(s).toBeInstanceOf(Float32Array)
})
})
Expand Down
15 changes: 12 additions & 3 deletions provider/docs/src/corpus/search/embeddings.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { cos_sim, env, pipeline } from '@xenova/transformers'
import * as onnxWeb from 'onnxruntime-web'
import { type CorpusIndex, type CorpusSearchResult } from '..'
import { useWebWorker } from '../../env'
import { embedTextOnWorker } from '../../mlWorker/webWorkerClient'
import { memo, noopCache, type CorpusCache } from '../cache/cache'

// eslint-disable-next-line @typescript-eslint/prefer-optional-chain
Expand Down Expand Up @@ -51,10 +53,17 @@ function cachedEmbedText(text: string, cache: CorpusCache): Promise<Float32Array

const pipe = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2', {})

/**
 * Embed the text and return the vector.
 *
 * The implementation is chosen once, when this module is evaluated: the worker-backed
 * `embedTextOnWorker` when `useWebWorker` is true, otherwise `embedTextInThisScope`.
 */
export const embedText: typeof embedTextInThisScope = useWebWorker ? embedTextOnWorker : embedTextInThisScope

/**
* Embed the text and return the vector.
*
* Run in the current scope (instead of in a worker in some environments).
*/
export async function embedText(text: string): Promise<Float32Array> {
export async function embedTextInThisScope(text: string): Promise<Float32Array> {
try {
console.time('embed')
const out = await pipe(text, { pooling: 'mean', normalize: true })
Expand All @@ -70,8 +79,8 @@ export async function embedText(text: string): Promise<Float32Array> {
* Compute the cosine similarity of the two texts' embeddings vectors.
*/
export async function similarity(text1: string, text2: string): Promise<number> {
const emb1 = await embedText(text1)
const emb2 = await embedText(text2)
const emb1 = await embedTextInThisScope(text1)
const emb2 = await embedTextInThisScope(text2)
return cos_sim(emb1, emb2)
}

Expand Down
3 changes: 3 additions & 0 deletions provider/docs/src/env.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
/** Whether a global `window` object exists in the current runtime. */
const isWebWindowRuntime = typeof window !== 'undefined'

/**
 * Whether work should be offloaded to a Web Worker.
 *
 * True only when a global `window` exists and the `Worker` constructor is also defined.
 * The extra `Worker` check keeps environments that define `window` without implementing
 * workers (for example DOM emulation layers used in tests) on the in-scope code path
 * instead of failing when a worker is first constructed.
 */
export const useWebWorker = isWebWindowRuntime && typeof Worker !== 'undefined'
7 changes: 7 additions & 0 deletions provider/docs/src/mlWorker/api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/**
 * The shape of one kind of request/response exchange with the ML worker.
 *
 * - `type` is the string tag that identifies the kind of request.
 * - `request` carries a numeric `id`, the `type` tag, and the arguments.
 * - `response` carries an `id` and the result. (The client and worker in this package
 *   copy the request's `id` into the response to correlate the two.)
 */
export interface MLWorkerMessagePair<
    TType extends string = string,
    TArgs extends {} = {},
    TResult extends {} = {}
> {
    type: TType
    request: { id: number; type: TType; args: TArgs }
    response: { id: number; result: TResult }
}

/**
 * The `'embedText'` exchange: the argument is a `string` and the result is a `Float32Array`.
 */
export interface MLWorkerEmbedTextMessage extends MLWorkerMessagePair<'embedText', string, Float32Array> {}
27 changes: 27 additions & 0 deletions provider/docs/src/mlWorker/webWorker.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/// <reference lib="webworker" />

import { embedTextInThisScope } from '../corpus/search/embeddings'
import { MLWorkerEmbedTextMessage, MLWorkerMessagePair } from './api'

declare var self: DedicatedWorkerGlobalScope

// Answer 'embedText' requests by computing the embedding inside this worker's own scope.
onRequest<MLWorkerEmbedTextMessage>('embedText', text => embedTextInThisScope(text))

// The request handler is registered above, so announce to the host that this worker is ready.
self.postMessage('ready')

/**
 * Register a handler for one kind of request sent to this worker.
 *
 * Every incoming `message` event is inspected. Events whose `type` differs from the given
 * `type` are ignored. For a matching event, `handler` is awaited with the request's `args`
 * and the result is posted back together with the request's `id`.
 *
 * NOTE(review): if `handler` rejects, no response is posted for that request.
 *
 * @param type The request type tag this handler serves.
 * @param handler Computes the result for a request's arguments.
 */
function onRequest<P extends MLWorkerMessagePair>(
    type: P['type'],
    handler: (args: P['request']['args']) => Promise<P['response']['result']>
): void {
    const listener = async (event: MessageEvent): Promise<void> => {
        const incoming = event.data as P['request']
        if (incoming.type !== type) {
            // Not addressed to this handler.
            return
        }
        const requestID = incoming.id
        const result = await handler(incoming.args)
        const reply: P['response'] = { id: requestID, result }
        self.postMessage(reply)
    }
    self.addEventListener('message', listener)
}
54 changes: 54 additions & 0 deletions provider/docs/src/mlWorker/webWorkerClient.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import { type embedTextInThisScope } from '../corpus/search/embeddings'
import { type MLWorkerEmbedTextMessage, type MLWorkerMessagePair } from './api'

/**
 * Embed the text by sending an `'embedText'` request to a worker, and return the vector
 * from the worker's response. Has the same signature as `embedTextInThisScope`.
 */
export const embedTextOnWorker: typeof embedTextInThisScope = async (text: string): Promise<Float32Array> => {
    const embedding = await sendMessage<MLWorkerEmbedTextMessage>('embedText', text)
    return embedding
}

/**
 * Send one request to a worker and resolve with the result of the matching response.
 *
 * A worker is acquired first, then a fresh request ID is allocated and the request is
 * posted. Responses are matched to requests by `id`; messages carrying a different `id`
 * are left for other listeners.
 *
 * @param type The request type tag.
 * @param args The arguments for the request.
 * @returns The `result` field of the response whose `id` equals this request's `id`.
 */
async function sendMessage<P extends MLWorkerMessagePair>(
    type: P['type'],
    args: P['request']['args']
): Promise<P['response']['result']> {
    const worker = await acquireWorker()
    const id = nextID()
    const request = { id, type, args } satisfies P['request']
    worker.postMessage(request)

    return new Promise<P['response']['result']>(resolve => {
        function handleResponse(event: MessageEvent): void {
            const reply = event.data as P['response']
            if (reply.id !== id) {
                // A reply to some other in-flight request on the same worker.
                return
            }
            resolve(reply.result)
            // This request is settled, so stop listening.
            worker.removeEventListener('message', handleResponse)
        }
        worker.addEventListener('message', handleResponse)
    })
}

/**
 * The size of the worker pool: the reported hardware concurrency, or 2 when it is not
 * available.
 *
 * `navigator` is probed with `typeof` because this module may be imported in runtimes that
 * do not define that global; an unguarded read would throw a ReferenceError at import time.
 */
const NUM_WORKERS = (typeof navigator !== 'undefined' && navigator.hardwareConcurrency) || 2

/** The pool's slots, indexed by worker ID. A slot is empty until its worker is first requested. */
const workers: (Promise<Worker> | undefined)[] = []

/** A counter of acquisitions so far, used to pick the next slot in round-robin order. */
let workerSeq = 0

/**
 * Acquire a worker from the pool. Currently the acquisition is round-robin.
 *
 * A slot's worker is created the first time the slot is selected. The slot stores a
 * promise, so concurrent callers that select the same slot share one worker and all wait
 * for it to become ready.
 *
 * @returns The worker, once it has posted its first message (its readiness signal).
 * @throws Rejects if the worker reports an `error` event before it becomes ready. The
 * rejected promise stays in the slot, so later acquisitions of that slot reject as well.
 */
async function acquireWorker(): Promise<Worker> {
    const workerID = workerSeq++ % NUM_WORKERS
    let workerInstance = workers[workerID]
    if (!workerInstance) {
        workerInstance = new Promise<Worker>((resolve, reject) => {
            const worker = new Worker(new URL('./webWorker.ts', import.meta.url), { type: 'module' })

            // Wait for worker to become ready. It sends a message when it is ready. The actual
            // message doesn't matter. Listen only once so that later responses from the worker
            // do not keep invoking this listener.
            const onReady = (): void => {
                worker.removeEventListener('error', onError)
                console.log('worker', workerID, 'is ready')
                resolve(worker)
            }

            // Without this, a worker that fails before signalling readiness would leave every
            // caller waiting forever.
            const onError = (event: ErrorEvent): void => {
                worker.removeEventListener('message', onReady)
                reject(new Error(`worker ${workerID} failed to start: ${event.message || 'unknown error'}`))
            }

            worker.addEventListener('message', onReady, { once: true })
            worker.addEventListener('error', onError, { once: true })
        })
        workers[workerID] = workerInstance
    }
    return workerInstance
}

/** The ID that the next call to `nextID` will return. Starts at 1. */
let id = 1

/**
 * Allocate a request ID. Each call returns a value one greater than the previous call's.
 */
function nextID(): number {
    const allocated = id
    id += 1
    return allocated
}

0 comments on commit 250dda0

Please sign in to comment.