From 1cf1a4554b9ad94982c07262f9744624e8ec8813 Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Wed, 15 Jan 2025 04:31:09 -0800 Subject: [PATCH] Add option to skip special tokens in TextStreamer (#1139) * Add option to skip special tokens in TextStreamer to be like WhisperTextStreamer * Re-order decode kwargs --------- Co-authored-by: Joshua Lochner --- src/generation/streamers.js | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/generation/streamers.js b/src/generation/streamers.js index effe2e01f..33c882081 100644 --- a/src/generation/streamers.js +++ b/src/generation/streamers.js @@ -37,6 +37,7 @@ export class TextStreamer extends BaseStreamer { * @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer * @param {Object} options * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens + * @param {boolean} [options.skip_special_tokens=true] Whether to skip special tokens when decoding * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display * @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method @@ -45,6 +46,7 @@ export class TextStreamer extends BaseStreamer { skip_prompt = false, callback_function = null, token_callback_function = null, + skip_special_tokens = true, decode_kwargs = {}, ...kwargs } = {}) { @@ -53,7 +55,7 @@ export class TextStreamer extends BaseStreamer { this.skip_prompt = skip_prompt; this.callback_function = callback_function ?? stdout_write; this.token_callback_function = token_callback_function; - this.decode_kwargs = { ...decode_kwargs, ...kwargs }; + this.decode_kwargs = { skip_special_tokens, ...decode_kwargs, ...kwargs }; // variables used in the streaming process this.token_cache = []; @@ -169,9 +171,10 @@ export class WhisperTextStreamer extends TextStreamer { } = {}) { super(tokenizer, { skip_prompt, + skip_special_tokens, callback_function, token_callback_function, - decode_kwargs: { skip_special_tokens, ...decode_kwargs }, + decode_kwargs, }); this.timestamp_begin = tokenizer.timestamp_begin;