Skip to content

Commit

Permalink
Add option to skip special tokens in TextStreamer (#1139)
Browse files Browse the repository at this point in the history
* Add option to skip special tokens in TextStreamer to be like WhisperTextStreamer

* Re-order decode kwargs

---------

Co-authored-by: Joshua Lochner <[email protected]>
  • Loading branch information
sroussey and xenova authored Jan 15, 2025
1 parent a938a56 commit 1cf1a45
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/generation/streamers.js
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export class TextStreamer extends BaseStreamer {
* @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer
* @param {Object} options
* @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
* @param {boolean} [options.skip_special_tokens=true] Whether to skip special tokens when decoding
* @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
* @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated
* @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method
Expand All @@ -45,6 +46,7 @@ export class TextStreamer extends BaseStreamer {
skip_prompt = false,
callback_function = null,
token_callback_function = null,
skip_special_tokens = true,
decode_kwargs = {},
...kwargs
} = {}) {
Expand All @@ -53,7 +55,7 @@ export class TextStreamer extends BaseStreamer {
this.skip_prompt = skip_prompt;
this.callback_function = callback_function ?? stdout_write;
this.token_callback_function = token_callback_function;
this.decode_kwargs = { ...decode_kwargs, ...kwargs };
this.decode_kwargs = { skip_special_tokens, ...decode_kwargs, ...kwargs };

// variables used in the streaming process
this.token_cache = [];
Expand Down Expand Up @@ -169,9 +171,10 @@ export class WhisperTextStreamer extends TextStreamer {
} = {}) {
super(tokenizer, {
skip_prompt,
skip_special_tokens,
callback_function,
token_callback_function,
decode_kwargs: { skip_special_tokens, ...decode_kwargs },
decode_kwargs,
});
this.timestamp_begin = tokenizer.timestamp_begin;

Expand Down

0 comments on commit 1cf1a45

Please sign in to comment.