inference.py
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, process_images
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.model import *
import torch
from PIL import Image


class ChartCoder:
    def __init__(
        self,
        temperature=0.1,
        max_tokens=2048,
        top_p=0.95,
        context_length=2048,
    ):
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.context_length = context_length
        # Note: change to your path
        pretrained = "/mnt/afs/chartcoder"
        model_name = "llava_deepseekcoder"
        device_map = "auto"
        self.system_message = ""
        tokenizer, model, image_processor, max_length = load_pretrained_model(
            pretrained, None, model_name, device_map=device_map
        )
        model.eval()
        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        self.IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX

    def generate(self, conversation):
        # Load the chart image referenced by the first message.
        image = Image.open(conversation[0]['content'][1]['image_url']).convert('RGB')
        # Build a DeepSeek-Coder-style instruction prompt with the image placeholder.
        prompt = self.system_message + f"### Instruction:\n{DEFAULT_IMAGE_TOKEN}\n{conversation[0]['content'][0]['text']}\n### Response:\n"
        # Tokenize, mapping the image placeholder to IMAGE_TOKEN_INDEX.
        input_ids = tokenizer_image_token(prompt, self.tokenizer, self.IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        image_tensor = process_images([image], self.image_processor, self.model.config)[0]
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).half().cuda(),
                image_sizes=[image.size],
                do_sample=self.temperature > 0,
                temperature=self.temperature,
                top_p=self.top_p,
                max_new_tokens=self.max_tokens,
                use_cache=True)
        # Decode the generated token ids back to text.
        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        return outputs
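

# Example usage: a minimal sketch of how ChartCoder.generate() might be
# called. The conversation layout below (content[0] = text, content[1] =
# image path) is inferred from generate() above; the image path and
# question are hypothetical placeholders, not part of the original code.
if __name__ == "__main__":
    coder = ChartCoder(temperature=0.1, max_tokens=2048)
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Generate the matplotlib code that reproduces this chart."},
                {"type": "image_url", "image_url": "example_chart.png"},  # hypothetical path
            ],
        }
    ]
    print(coder.generate(conversation))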