
Converting a Vision Transformer model with pre-built engines #319

sayakpaul commented Aug 15, 2022

System information

NVIDIA

$ nvidia-smi
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   56C    P0    28W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

I am using an NGC container for this work. Here is how I am running the Docker image:

$ nvidia-docker run -it --rm --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 nvcr.io/nvidia/tensorflow:22.07-tf2-py3

After this, I get terminal access to the container.
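As a quick sanity check (not strictly necessary), I also confirm from inside the container that the GPU is visible to TensorFlow and that the build ships with TensorRT support:

import tensorflow as tf

# Should list the Tesla T4 as a physical GPU device.
print(tf.config.list_physical_devices("GPU"))

# Should print True for the NGC TensorFlow images, which are TensorRT-enabled builds.
print(tf.sysconfig.get_build_info()["is_tensorrt_build"])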

TensorFlow build details within the container

tf.sysconfig.get_build_info()

OrderedDict([('cpu_compiler', '/opt/rh/devtoolset-9/root/usr/bin/gcc'), ('cuda_compute_capabilities', ['compute_86', 'sm_52', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86']), ('cuda_version', '11.7'), ('cudnn_version', '8'), ('is_cuda_build', True), ('is_rocm_build', False), ('is_tensorrt_build', True)])

Issue

I am trying to convert a ViT B/16 model from transformers. First, I serialize it as a SavedModel:

from transformers import ViTFeatureExtractor, TFViTForImageClassification
import tensorflow as tf
import transformers
import tempfile
import requests
import base64
import json
import os

LOCAL_MODEL_DIR = "vit"


def normalize_img(img, mean=feature_extractor.image_mean, std=feature_extractor.image_std):
    # Scale to the value range of [0, 1] first and then normalize.
    img = img / 255
    mean = tf.constant(mean)
    std = tf.constant(std)
    return (img - mean) / std

def preprocess(string_input):
    decoded = tf.io.decode_jpeg(string_input, channels=3)
    resized = tf.image.resize(decoded, size=(SIZE, SIZE))
    normalized = normalize_img(resized)
    normalized = tf.transpose(normalized, (2, 0, 1)) # Since HF models are channel-first.
    return normalized


@tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
def preprocess_fn(string_input):
    decoded_images = tf.map_fn(
        preprocess, string_input, fn_output_signature=tf.float32
    )
    return {CONCRETE_INPUT: decoded_images}


def model_exporter(model: tf.keras.Model):
    m_call = tf.function(model.call).get_concrete_function(
        tf.TensorSpec(
            shape=[None, 3, SIZE, SIZE], dtype=tf.float32, name=CONCRETE_INPUT
        )
    )

    @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
    def serving_fn(string_input):
        labels = tf.constant(
            list(model.config.id2label.values()), dtype=tf.string
        )
        images = preprocess_fn(string_input)

        predictions = m_call(**images)
        indices = tf.argmax(predictions.logits, axis=1)
        pred_source = tf.gather(params=labels, indices=indices)
        probs = tf.nn.softmax(predictions.logits, axis=1)
        pred_confidence = tf.reduce_max(probs, axis=1)
        return {"label": pred_source, "confidence": pred_confidence}

    return serving_fn



# Load the pre-trained ViT B/16 classifier from the Hugging Face Hub.
print("Loading model.")
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
print("Model loaded.")


# Determine model variables.
feature_extractor = ViTFeatureExtractor()
concerete_input = "pixel_values"
size = feature_extractor.size
input_shape = (SIZE, SIZE, 3)


print("Saving model.")
tf.saved_model.save(
    model,
    LOCAL_MODEL_DIR,
    signatures={"serving_default": model_exporter(model)},
)
print("Model saved.")

Then I run the conversion:

import glob

import tensorflow as tf

LOCAL_IMAGE_PATH = "imagenette-validation-samples"
ORIGINAL_MODEL_PATH = "vit"
TENSORRT_MODEL_DIR = "vit-tensorrt"


def convert_to_string(image_path):
    with open(image_path, "rb") as f:
        image_string = f.read()
    return image_string


def calibration_input_fn(image_bytes):
    # `image_bytes` is a list of single-element lists of raw image bytes, so each
    # yielded tensor should be a string tensor of shape [1] (a batch of one).
    def fn():
        for img_bytes in image_bytes:
            yield tf.convert_to_tensor(img_bytes)
    return fn


all_image_paths = glob.glob(f"{LOCAL_IMAGE_PATH}/*.png")
print(f"Total images found: {len(all_image_paths)}.")

all_images_bytes = [[convert_to_string(image_path)] for image_path in all_image_paths]
print(f"Length of the totyal image image bytes: {len(all_images_bytes)}.")


params = tf.experimental.tensorrt.ConversionParams(
    precision_mode="FP16",
    max_workspace_size_bytes=2 << 32,  # 8,589,934,592 bytes
    maximum_cached_engines=100,
    minimum_segment_size=3,
    allow_build_at_runtime=True,
)
converter = tf.experimental.tensorrt.Converter(
    input_saved_model_dir=ORIGINAL_MODEL_PATH, conversion_params=params
)
converter.convert()
converter.build(input_fn=calibration_input_fn(all_images_bytes))
converter.save(TENSORRT_MODEL_DIR)

To get the imagenette-validation-samples directory, run the following from the container:

$ wget https://github.com/sayakpaul/deploy-hf-tf-vision-models/releases/download/3.0/imagenette-validation-samples.tar.gz
$ tar xf imagenette-validation-samples.tar.gz
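For reference, here is how I inspected what the input function yields (continuing from the conversion script above); I expect a rank-1 string tensor per yield:

# Grab the first value produced by the calibration input function.
sample_fn = calibration_input_fn(all_images_bytes)
first = next(iter(sample_fn()))
print(first.dtype, first.shape)  # expected: tf.string, shape (1,)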

When running the conversion, I get the following error:

Traceback (most recent call last):
  File "convert_to_tensor.py", line 41, in <module>
    converter.build(input_fn=calibration_input_fn(all_images_bytes))
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/compiler/tensorrt/trt_convert.py", line 1447, in build
    func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/function.py", line 1602, in __call__
    return self._call_impl(args, kwargs)
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/wrap_function.py", line 243, in _call_impl
    return super(WrappedFunction, self)._call_impl(
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/function.py", line 1620, in _call_impl
    return self._call_with_flat_signature(args, kwargs, cancellation_manager)
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/function.py", line 1669, in _call_with_flat_signature
    return self._call_flat(args, self.captured_inputs, cancellation_manager)
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/function.py", line 1860, in _call_flat
    return self._build_call_outputs(self._inference_function.call(
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/function.py", line 497, in call
    outputs = execute.execute(
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node 'StatefulPartitionedCall/PartitionedCall/map/TensorArrayUnstack/TensorListFromTensor' defined at (most recent call last):
    File "convert_to_tensor.py", line 40, in <module>
      converter.convert()
Node: 'StatefulPartitionedCall/PartitionedCall/map/TensorArrayUnstack/TensorListFromTensor'
Detected at node 'StatefulPartitionedCall/PartitionedCall/map/TensorArrayUnstack/TensorListFromTensor' defined at (most recent call last):
    File "convert_to_tensor.py", line 40, in <module>
      converter.convert()
Node: 'StatefulPartitionedCall/PartitionedCall/map/TensorArrayUnstack/TensorListFromTensor'
2 root error(s) found.
  (0) INVALID_ARGUMENT:  Tensor must be at least a vector, but saw shape: []
         [[{{node StatefulPartitionedCall/PartitionedCall/map/TensorArrayUnstack/TensorListFromTensor}}]]
         [[StatefulPartitionedCall/GatherV2/_426]]
  (1) INVALID_ARGUMENT:  Tensor must be at least a vector, but saw shape: []
         [[{{node StatefulPartitionedCall/PartitionedCall/map/TensorArrayUnstack/TensorListFromTensor}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_pruned_43026]

Without the converter.build() step, the conversion succeeds, but the resulting model's inference latency is higher.
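For completeness, this is the variant that does work for me, i.e., the same conversion without the build step:

import tensorflow as tf

params = tf.experimental.tensorrt.ConversionParams(
    precision_mode="FP16",
    max_workspace_size_bytes=2 << 32,
    maximum_cached_engines=100,
    minimum_segment_size=3,
    allow_build_at_runtime=True,
)
converter = tf.experimental.tensorrt.Converter(
    input_saved_model_dir="vit", conversion_params=params
)
converter.convert()
# No converter.build() here: with allow_build_at_runtime=True the engines are
# built lazily on the first inference calls, which is where the extra latency shows up.
converter.save("vit-tensorrt")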

Notes

I made the model accept compressed (JPEG-encoded) image strings to reduce request payload sizes.
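For context, this is roughly how a client would build a request payload if the SavedModel were deployed behind TensorFlow Serving's REST API (a sketch under that assumption; the endpoint and image path are placeholders):

import base64
import json
import requests

# Base64-encode the raw JPEG bytes so they can travel inside a JSON body.
with open("sample.jpg", "rb") as f:  # placeholder image path
    b64_image = base64.b64encode(f.read()).decode("utf-8")

# TensorFlow Serving's REST API turns {"b64": ...} values back into raw bytes.
payload = json.dumps({"instances": [{"b64": b64_image}]})
response = requests.post(
    "http://localhost:8501/v1/models/vit:predict",  # placeholder endpoint
    data=payload,
)
print(response.json())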

What am I missing?
