Rename FlaxWhisperPipline -> FlaxWhisperPipeline #114

Open · wants to merge 1 commit into main
README.md (28 changes: 14 additions & 14 deletions)
@@ -32,17 +32,17 @@ pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit

## Pipeline Usage

-The recommended way of running Whisper JAX is through the [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) abstraction class. This class handles all
+The recommended way of running Whisper JAX is through the [`FlaxWhisperPipeline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) abstraction class. This class handles all
the necessary pre- and post-processing, as well as wrapping the generate method for data parallelism across accelerator devices.

Whisper JAX makes use of JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function for data parallelism across GPU/TPU devices. This function is _Just In Time (JIT)_
compiled the first time it is called. Thereafter, the function will be _cached_, enabling it to be run in super-fast time:

```python
-from whisper_jax import FlaxWhisperPipline
+from whisper_jax import FlaxWhisperPipeline

# instantiate pipeline
pipeline = FlaxWhisperPipline("openai/whisper-large-v2")
pipeline = FlaxWhisperPipeline("openai/whisper-large-v2")

# JIT compile the forward call - slow, but we only do once
text = pipeline("audio.mp3")
@@ -59,11 +59,11 @@ of the model weights.

For most GPUs, the dtype should be set to `jnp.float16`. For A100 GPUs or TPUs, the dtype should be set to `jnp.bfloat16`:
```python
-from whisper_jax import FlaxWhisperPipline
+from whisper_jax import FlaxWhisperPipeline
import jax.numpy as jnp

# instantiate pipeline in bfloat16
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
pipeline = FlaxWhisperPipeline("openai/whisper-large-v2", dtype=jnp.bfloat16)
```

### Batching
@@ -75,10 +75,10 @@ provides a 10x speed-up compared to transcribing the audio samples sequentially,
To enable batching, pass the `batch_size` parameter when you instantiate the pipeline:

```python
-from whisper_jax import FlaxWhisperPipline
+from whisper_jax import FlaxWhisperPipeline

# instantiate pipeline with batching
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16)
pipeline = FlaxWhisperPipeline("openai/whisper-large-v2", batch_size=16)
```

### Task
@@ -93,7 +93,7 @@ text = pipeline("audio.mp3", task="translate")

### Timestamps

-The [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) also supports timestamp prediction. Note that enabling timestamps will require a second JIT compilation of the
+The [`FlaxWhisperPipeline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) also supports timestamp prediction. Note that enabling timestamps will require a second JIT compilation of the
forward call, this time including the timestamp outputs:

```python
@@ -108,11 +108,11 @@ In the following code snippet, we instantiate the model in bfloat16 precision wi
returning timestamps tokens:

```python
-from whisper_jax import FlaxWhisperPipline
+from whisper_jax import FlaxWhisperPipeline
import jax.numpy as jnp

# instantiate pipeline with bfloat16 and enable batching
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
pipeline = FlaxWhisperPipeline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)

# transcribe and return timestamps
outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
@@ -188,7 +188,7 @@ the next time they are required. Note that converting weights from PyTorch to Fl

For example, to convert the fine-tuned checkpoint [`sanchit-gandhi/whisper-small-hi`](https://huggingface.co/sanchit-gandhi/whisper-small-hi) from the blog post [Fine-Tuning Whisper](https://huggingface.co/blog/fine-tune-whisper):
```python
-from whisper_jax import FlaxWhisperForConditionalGeneration, FlaxWhisperPipline
+from whisper_jax import FlaxWhisperForConditionalGeneration, FlaxWhisperPipeline
import jax.numpy as jnp

checkpoint_id = "sanchit-gandhi/whisper-small-hi"
@@ -198,7 +198,7 @@ model = FlaxWhisperForConditionalGeneration.from_pretrained(checkpoint_id, from_
model.push_to_hub(checkpoint_id)

# now we can load the Flax weights directly as required
-pipeline = FlaxWhisperPipline(checkpoint_id, dtype=jnp.bfloat16, batch_size=16)
+pipeline = FlaxWhisperPipeline(checkpoint_id, dtype=jnp.bfloat16, batch_size=16)
```

## Advanced Usage
@@ -212,7 +212,7 @@ The following code snippet demonstrates how data parallelism can be achieved usi
an entirely equivalent way to `pmap`:

```python
-from whisper_jax import FlaxWhisperPipline
+from whisper_jax import FlaxWhisperPipeline
import jax.numpy as jnp

# 2D parameter and activation partitioning for DP
@@ -230,7 +230,7 @@ logical_axis_rules_dp = (
("channels", None),
)

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
pipeline = FlaxWhisperPipeline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
pipeline.shard_params(num_mp_partitions=1, logical_axis_rules=logical_axis_rules_dp)
```
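
Once the parameters are sharded, transcription runs exactly as in the `pmap` examples above; a minimal usage sketch (not part of this PR's changes, reusing the pipeline defined in the snippet):

```python
# the first call compiles the pjit-sharded forward pass; later calls reuse it
text = pipeline("audio.mp3")
```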

app/app.py (4 changes: 2 additions & 2 deletions)
@@ -13,7 +13,7 @@
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
from transformers.pipelines.audio_utils import ffmpeg_read

-from whisper_jax import FlaxWhisperPipline
+from whisper_jax import FlaxWhisperPipeline


cc.initialize_cache("./jax_cache")
@@ -73,7 +73,7 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal


if __name__ == "__main__":
-    pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
+    pipeline = FlaxWhisperPipeline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
stride_length_s = CHUNK_LENGTH_S / 6
chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate)
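For context on the hunk above, the chunk and stride sizes follow directly from the chunk duration and the feature extractor's sampling rate. A worked sketch, assuming `CHUNK_LENGTH_S = 30` and Whisper's 16 kHz feature extractor (both are defined outside the lines shown):

```python
CHUNK_LENGTH_S = 30      # assumed value; defined elsewhere in app.py
sampling_rate = 16_000   # Whisper feature extractors operate on 16 kHz audio

stride_length_s = CHUNK_LENGTH_S / 6                                  # 5.0 s of overlap per side
chunk_len = round(CHUNK_LENGTH_S * sampling_rate)                     # 480_000 samples per chunk
stride_left = stride_right = round(stride_length_s * sampling_rate)   # 80_000 samples each
```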
2 changes: 1 addition & 1 deletion whisper-jax-tpu.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion whisper_jax/__init__.py
Original file line number Diff line number Diff line change
@@ -17,5 +17,5 @@

from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
from .partitioner import PjitPartitioner
-from .pipeline import FlaxWhisperPipline
+from .pipeline import FlaxWhisperPipeline
from .train_state import InferenceState
2 changes: 1 addition & 1 deletion whisper_jax/pipeline.py
Original file line number Diff line number Diff line change
@@ -54,7 +54,7 @@
)


-class FlaxWhisperPipline:
+class FlaxWhisperPipeline:
def __init__(
self,
checkpoint="openai/whisper-large-v2",