Skip to content

Commit

Permalink
Demo version of ResNet 50 (#1133)
Browse files Browse the repository at this point in the history
### Ticket
Fix #1132

### Problem description
Confirm that ResNet 50 works with valid demo (image classification). 

### What's changed
Created demo version of ResNet 50 that runs on a few images and outputs
a table for CPU comparison (predicted class and confidence score).

Additionally, unified existing ResNet tests into a single file.

### Checklist
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
nvukobratTT authored Jan 30, 2025
1 parent 2592480 commit 951cb38
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 225 deletions.
29 changes: 0 additions & 29 deletions forge/test/mlir/resnet/test_resnet_inference.py

This file was deleted.

145 changes: 0 additions & 145 deletions forge/test/mlir/resnet/test_resnet_unique_ops.py

This file was deleted.

125 changes: 78 additions & 47 deletions forge/test/models/pytorch/vision/resnet/test_resnet.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import random

import pytest
import requests
import timm
import torch
from datasets import load_dataset
from loguru import logger
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from tabulate import tabulate
from torchvision.models.resnet import resnet50
from transformers import AutoImageProcessor, ResNetForImageClassification

import forge
from forge.verify.config import VerifyConfig
from forge.verify.value_checkers import AutomaticValueChecker
from forge.verify.verify import verify

from test.models.utils import Framework, Source, Task, build_module_name
Expand All @@ -27,7 +28,9 @@
@pytest.mark.nightly
@pytest.mark.parametrize("variant", variants, ids=variants)
def test_resnet_hf(variant, record_forge_property):
# Record model properties
random.seed(0)

# Record model details
module_name = build_module_name(
framework=Framework.PYTORCH,
model="resnet",
Expand All @@ -37,69 +40,97 @@ def test_resnet_hf(variant, record_forge_property):
)
record_forge_property("model_name", module_name)

# Load dataset
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
# Load tiny dataset
dataset = load_dataset("zh-plus/tiny-imagenet")
images = random.sample(dataset["valid"]["image"], 10)

# Load Torch model, preprocess image, and label dictionary
processor = download_model(AutoImageProcessor.from_pretrained, variant)
# Load framework model
framework_model = download_model(ResNetForImageClassification.from_pretrained, variant, return_dict=False)
label_dict = framework_model.config.id2label

inputs = processor(image, return_tensors="pt")
inputs = inputs["pixel_values"]
# Compile model
input_sample = [torch.rand(1, 3, 224, 224)]
compiled_model = forge.compile(framework_model, input_sample)

compiled_model = forge.compile(framework_model, inputs)
# Verify data on sample input
verify(input_sample, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.95)))

cpu_logits = framework_model(inputs)[0]
cpu_pred = label_dict[cpu_logits.argmax(-1).item()]
# Run model on sample data and print results
run_and_print_results(framework_model, compiled_model, images)

tt_logits = compiled_model(inputs)[0]
tt_pred = label_dict[tt_logits.argmax(-1).item()]

assert cpu_pred == tt_pred, f"Inference mismatch: CPU prediction: {cpu_pred}, TT prediction: {tt_pred}"
def run_and_print_results(framework_model, compiled_model, inputs):
"""
Runs inference using both a framework model and a compiled model on a list of input images,
then prints the results in a formatted table.
verify([inputs], framework_model, compiled_model)
Args:
framework_model: The original framework-based model.
compiled_model: The compiled version of the model.
inputs: A list of images to process and classify.
"""
label_dict = framework_model.config.id2label
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")

results = []
for i, image in enumerate(inputs):
processed_inputs = processor(image, return_tensors="pt")["pixel_values"]

def generate_model_resnet_imgcls_timm_pytorch(variant):
# Load ResNet50 feature extractor and model from TIMM
model = download_model(timm.create_model, variant, pretrained=True)
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
cpu_logits = framework_model(processed_inputs)[0]
cpu_conf, cpu_idx = cpu_logits.softmax(-1).max(-1)
cpu_pred = label_dict.get(cpu_idx.item(), "Unknown")

# Load data sample
try:
url = "https://images.rawpixel.com/image_1300/cHJpdmF0ZS9sci9pbWFnZXMvd2Vic2l0ZS8yMDIyLTA1L3BkMTA2LTA0Ny1jaGltXzEuanBn.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
except:
logger.warning(
"Failed to download the image file, replacing input with random tensor. Please check if the URL is up to date"
)
image = torch.rand(1, 3, 256, 256)
tt_logits = compiled_model(processed_inputs)[0]
tt_conf, tt_idx = tt_logits.softmax(-1).max(-1)
tt_pred = label_dict.get(tt_idx.item(), "Unknown")

# Data preprocessing
pixel_values = transform(image).unsqueeze(0)
results.append([i + 1, cpu_pred, cpu_conf.item(), tt_pred, tt_conf.item()])

return model, [pixel_values], {}
print(
tabulate(
results,
headers=["Example", "CPU Prediction", "CPU Confidence", "Compiled Prediction", "Compiled Confidence"],
tablefmt="grid",
)
)


@pytest.mark.nightly
def test_resnet_timm(record_forge_property):
pytest.skip("Skipping due to the current CI/CD pipeline limitations")

# Build Module Name
# Record model details
module_name = build_module_name(
framework=Framework.PYTORCH, model="resnet", source=Source.TIMM, variant="50", task=Task.IMAGE_CLASSIFICATION
)
record_forge_property("model_name", module_name)

# Load framework model
framework_model = download_model(timm.create_model, "resnet50", pretrained=True)

# Record Forge Property
# Compile model
input_sample = [torch.rand(1, 3, 224, 224)]
compiled_model = forge.compile(framework_model, sample_inputs=input_sample, module_name=module_name)

# Verify data on sample input
verify(input_sample, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.95)))


@pytest.mark.nightly
def test_resnet_torchvision(record_forge_property):
# Record model details
module_name = build_module_name(
framework=Framework.PYTORCH,
model="resnet",
source=Source.TORCHVISION,
variant="50",
task=Task.IMAGE_CLASSIFICATION,
)
record_forge_property("model_name", module_name)

framework_model, inputs, _ = generate_model_resnet_imgcls_timm_pytorch("resnet50")
# Load framework model
framework_model = resnet50()

# Forge compile framework model
compiled_model = forge.compile(framework_model, sample_inputs=inputs, module_name=module_name)
# Compile model
input_sample = [torch.rand(1, 3, 224, 224)]
compiled_model = forge.compile(framework_model, input_sample)

# Model Verification
verify(inputs, framework_model, compiled_model)
# Verify data on sample input
verify(input_sample, framework_model, compiled_model, VerifyConfig(value_checker=AutomaticValueChecker(pcc=0.95)))
4 changes: 0 additions & 4 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,6 @@ testpaths =
forge/test/mlir/llama/test_llama_inference.py::test_llama_inference
forge/test/mlir/llama/tests

# Resnet
forge/test/mlir/resnet/test_resnet_inference.py::test_resnet_inference
forge/test/mlir/resnet/test_resnet_unique_ops.py

# Benchmark
# MNIST Linear
forge/test/benchmark/benchmark/models/mnist_linear.py::test_mnist_linear
Expand Down

0 comments on commit 951cb38

Please sign in to comment.