Skip to content

Commit

Permalink
Add Llava model bringup (#1173)
Browse files Browse the repository at this point in the history
### Ticket
Fix #1177

### What's changed
Add Model tests for Llava Model

### Logs

[test_llava.log](https://github.com/user-attachments/files/18687248/test_llava.log)

### Current Issue
```
E           NotImplementedError: The following operators are not implemented: ['aten::masked_scatter']
```

And this will be solved in this PR -
#1188
  • Loading branch information
ashokkumarkannan1 authored Feb 8, 2025
1 parent 7286f5b commit b0a7d77
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 0 deletions.
Empty file.
65 changes: 65 additions & 0 deletions forge/test/models/pytorch/multimodal/llava/test_llava.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0


import pytest
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

import forge
from forge.verify.verify import verify

from .utils import load_inputs
from test.models.utils import Framework, Source, Task, build_module_name


class Wrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, input_ids, attention_mask, pixel_values):
inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
output = self.model(**inputs)
return output.logits


def load_model(variant):
processor = AutoProcessor.from_pretrained(variant)
model = LlavaForConditionalGeneration.from_pretrained(variant)
model = Wrapper(model)
return model, processor


variants = ["llava-hf/llava-1.5-7b-hf"]


@pytest.mark.nightly
@pytest.mark.parametrize("variant", variants, ids=variants)
def test_llava(record_forge_property, variant):
# Build Module Name
module_name = build_module_name(
framework=Framework.PYTORCH,
model="llava",
variant=variant,
task=Task.CONDITIONAL_GENERATION,
source=Source.HUGGINGFACE,
)

# Record Forge Property
record_forge_property("model_name", module_name)

framework_model, processor = load_model(variant)
image = "https://www.ilankelman.org/stopsigns/australia.jpg"
text = "What’s shown in this image?"

# Input sample
input_ids, attn_mask, pixel_values = load_inputs(image, text, processor)
inputs = [input_ids, attn_mask, pixel_values]

# Forge compile framework model
compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)

# Model Verification
verify(inputs, framework_model, compiled_model)
4 changes: 4 additions & 0 deletions forge/test/models/pytorch/multimodal/llava/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
from .utils import load_inputs
39 changes: 39 additions & 0 deletions forge/test/models/pytorch/multimodal/llava/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import re

import requests
from PIL import Image


def is_url(url):
regex = r"^(https?)://[^\s/$.?#].[^\s]*$"
return bool(re.match(regex, url))


def load_inputs(inp_image, text, processor):
conversation = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": text},
],
}
]
text_prompt = processor.apply_chat_template(conversation, padding=True, add_generation_prompt=True)
if is_url(inp_image):
image = Image.open(requests.get(inp_image, stream=True).raw)
else:
if os.path.isfile(inp_image):
image = Image.open(inp_image)
else:
raise ValueError("Input is neither a valid URL nor a valid file path.")

inputs = processor(images=image, text=text_prompt, return_tensors="pt")
input_ids = inputs["input_ids"]
attn_mask = inputs["attention_mask"]
pixel_values = inputs["pixel_values"]

return input_ids, attn_mask, pixel_values
1 change: 1 addition & 0 deletions forge/test/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Task(StrEnum):
OBJECT_DETECTION = "obj_det"
SEMANTIC_SEGMENTATION = "sem_seg"
MASKED_IMAGE_MODELLING = "masked_img"
CONDITIONAL_GENERATION = "cond_gen"
IMAGE_ENCODING = "img_enc"
VISUAL_BACKBONE = "visual_bb"

Expand Down

0 comments on commit b0a7d77

Please sign in to comment.