Add half precision inference and document
leng-yue committed Dec 19, 2023
1 parent 39e2a96 commit c583555
Showing 4 changed files with 19 additions and 4 deletions.
docs/en/inference.md: 3 additions & 0 deletions

@@ -48,6 +48,9 @@ This command will create a `codes_N` file in the working directory, where N is a
 You may want to use `--compile` to fuse CUDA kernels for faster inference (~30 tokens/second -> ~500 tokens/second).
 Correspondingly, if you do not plan to use acceleration, you can comment out the `--compile` parameter.
 
+!!! info
+    For GPUs that do not support bf16, you may need to use the `--half` parameter.
+
 ### 3. Generate vocals from semantic tokens:
 ```bash
 python tools/vqgan/inference.py \
docs/zh/inference.md: 3 additions & 0 deletions

@@ -53,6 +53,9 @@ python tools/llama/generate.py \
 You may want to use `--compile` to fuse CUDA kernels for faster inference (~30 tokens/second -> ~500 tokens/second).
 Correspondingly, if you do not plan to use acceleration, you can comment out the `--compile` parameter.
 
+!!! info
+    For GPUs that do not support bf16, you may need to use the `--half` parameter.
+
 ### 3. Generate vocals from semantic tokens:
 ```bash
 python tools/vqgan/inference.py \
fish_speech/models/text2semantic/llama.py: 6 additions & 2 deletions

@@ -110,7 +110,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.max_batch_size = -1
         self.max_seq_len = -1
 
-    def setup_caches(self, max_batch_size, max_seq_len):
+    def setup_caches(self, max_batch_size, max_seq_len, dtype=torch.bfloat16):
         if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
             return
 
@@ -121,7 +121,11 @@ def setup_caches(self, max_batch_size, max_seq_len):
 
         for b in self.layers:
             b.attention.kv_cache = KVCache(
-                max_batch_size, max_seq_len, self.config.n_local_heads, head_dim
+                max_batch_size,
+                max_seq_len,
+                self.config.n_local_heads,
+                head_dim,
+                dtype=dtype,
             )
 
     def embed(self, x: Tensor) -> Tensor:
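The `dtype` argument that `setup_caches` now forwards ends up in `KVCache`, whose constructor is not part of this diff. Below is a minimal sketch of what a dtype-aware cache might look like; the class name, buffer names, and `update` method are illustrative assumptions, not the repository's actual code:

```python
import torch
from torch import Tensor, nn


class KVCacheSketch(nn.Module):
    """Illustrative only: key/value buffers pre-allocated in the requested dtype."""

    def __init__(self, max_batch_size: int, max_seq_len: int, n_heads: int,
                 head_dim: int, dtype: torch.dtype = torch.bfloat16) -> None:
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        # Allocating the cache in bf16/fp16 halves its memory footprint versus fp32
        # and keeps it consistent with the model's compute dtype.
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # Write the new keys/values at the given positions, return the full caches.
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache
```

Without an explicit `dtype`, the cache would keep defaulting to bf16 even when the rest of the pipeline runs in fp16, which would otherwise cause dtype mismatches when the cached keys and values meet fp16 activations.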
tools/llama/generate.py: 7 additions & 2 deletions

@@ -207,6 +207,7 @@ def generate(
     prompt: torch.Tensor,
     max_new_tokens: int,
     eos_token_id: int = 2,
+    precision: torch.dtype = torch.bfloat16,
     **sampling_kwargs,
 ) -> torch.Tensor:
     """
@@ -228,7 +229,7 @@ def generate(
 
     device, dtype = prompt.device, prompt.dtype
     with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_len=T_new)
+        model.setup_caches(max_batch_size=1, max_seq_len=T_new, dtype=precision)
 
     codebook_dim = 1 + model.config.num_codebooks
     # create an empty tensor of the expected final shape and fill in the current tokens
@@ -381,6 +382,7 @@ def load_model(config_name, checkpoint_path, device, precision):
 @click.option("--use-g2p/--no-g2p", default=True)
 @click.option("--seed", type=int, default=42)
 @click.option("--speaker", type=str, default=None)
+@click.option("--half/--no-half", default=False)
 def main(
     text: str,
     prompt_text: Optional[str],
@@ -398,9 +400,11 @@ def main(
     use_g2p: bool,
     seed: int,
     speaker: Optional[str],
+    half: bool,
 ) -> None:
     device = "cuda"
-    precision = torch.bfloat16
+
+    precision = torch.half if half else torch.bfloat16
 
     logger.info("Loading model ...")
     t0 = time.time()
@@ -445,6 +449,7 @@ def main(
         prompt=encoded,
         max_new_tokens=max_new_tokens,
         eos_token_id=tokenizer.eos_token_id,
+        precision=precision,
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
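On the CLI side, the new `--half/--no-half` flag reduces to a single dtype choice that is then threaded through `generate` and `setup_caches`. A small self-contained sketch of that decision; the bf16 capability probe is not part of this commit and is shown only as an assumption about when `--half` is useful:

```python
import torch


def pick_precision(half: bool) -> torch.dtype:
    # Same rule the commit adds to main(): --half selects fp16, otherwise bf16.
    return torch.half if half else torch.bfloat16


if __name__ == "__main__":
    # Not in the commit: probing bf16 support, i.e. the situation the docs note
    # targets (GPUs without bf16 support should fall back to --half / fp16).
    supports_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    dtype = pick_precision(half=not supports_bf16)
    print(f"bf16 supported: {supports_bf16}, inference dtype: {dtype}")
```

In the commit itself the user makes this choice explicitly by passing `--half`; the resulting dtype is forwarded as `precision=precision` into `generate`, which in turn hands it to `setup_caches` so the KV cache matches the model's compute dtype.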
