Skip to content

Commit

Permalink
small refactoring
Browse files: browse the repository at this point in the history
  • Loading branch information
VladOS95-cyber committed Nov 5, 2024
1 parent b43581e commit 6e8fb0d
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions examples/inference/distributed/llava_next_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import time
from concurrent.futures import ThreadPoolExecutor

from itertools import chain
import fire
import torch
from datasets import load_dataset
Expand Down Expand Up @@ -100,11 +99,9 @@ def main(
else:
print(f"Directory '{save_dir}' already exists.")

captions = load_dataset("nkp37/OpenVid-1M", split="train")["caption"][:max_captions]

# split long-text captions into small sentences
splitted_captions = list(chain.from_iterable([captions[i].split(".") for i in range(len(captions))]))
batches = get_batches(splitted_captions, batch_size)
captions = load_dataset("nkp37/OpenVid-1M", split="train")["caption"]
reduced_captions = captions[: min(len(captions), max_captions)]
batches = get_batches(reduced_captions, batch_size)

output_queue = queue.Queue()
save_thread = ThreadPoolExecutor(max_workers=num_workers)
Expand All @@ -114,7 +111,7 @@ def main(
with distributed_state.split_between_processes(caption_batch) as caption:
input = processor(caption, padding=True, return_tensors="pt").to(model.device)
output = model.generate(**input, max_new_tokens=max_new_tokens)
generated_text = processor.decode(output[0][2:], skip_special_tokens=True)
generated_text = processor.batch_decode(output, skip_special_tokens=True)
output_queue.put((caption, generated_text))
finally:
output_queue.put(None)
Expand Down

0 comments on commit 6e8fb0d

Please sign in to comment.