OpenVINO Seq2Seq pipeline improvements (#131)
* Add async mode and shared memory

* Fixes and codestyle

* Add shared_memory helper
Jan Iwaszkiewicz authored Dec 25, 2022
1 parent c501878 commit 2936d24
Showing 2 changed files with 46 additions and 20 deletions.
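Before the diffs, a minimal, self-contained sketch of the async + shared-memory pattern the commit message describes, assuming an already exported OpenVINO IR; the model file name and the port names ("input_ids", "last_hidden_state") are placeholders for illustration, not taken from the repository:

import numpy as np
from openvino.runtime import Core, Tensor

core = Core()
compiled = core.compile_model("encoder_model.xml", "CPU")  # placeholder IR path
request = compiled.create_infer_request()

# The input array must be C-contiguous so it can be wrapped without a copy.
input_ids = np.ascontiguousarray(np.ones((1, 16), dtype=np.int64))

# shared_memory=True makes the Tensor a zero-copy view over the numpy buffer;
# the buffer has to stay alive until inference completes.
inputs = {"input_ids": Tensor(input_ids, shared_memory=True)}

request.start_async(inputs)
request.wait()  # block until the asynchronous inference finishes

last_hidden_state = request.get_tensor("last_hidden_state").data  # placeholder output name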
optimum/intel/openvino/modeling_base_seq2seq.py (2 changes: 1 addition & 1 deletion)
@@ -66,7 +66,7 @@ def __init__(
         self.model_save_dir = kwargs.get("model_save_dir")
         self._device = kwargs.get("device", "CPU")
         self.is_dynamic = kwargs.get("dynamic_shapes", True)
-        self.ov_config = {"PERFORMANCE_HINT": "LATENCY"}
+        self.ov_config = {}
         if "GPU" in self._device:
             raise ValueError("Support of dynamic shapes for GPU devices is not yet available.")
         if self.is_dynamic:
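For context, ov_config is the property dictionary that this base class later hands to OpenVINO when compiling the model. A rough illustration of the two values touched above, with a placeholder IR path (whether the hint is applied by default depends on the class version):

from openvino.runtime import Core

core = Core()

# With the hint set, OpenVINO tunes the compiled model for single-request latency;
# with an empty dict, the runtime defaults are used instead.
ov_config = {"PERFORMANCE_HINT": "LATENCY"}  # or: ov_config = {}
compiled = core.compile_model("encoder_model.xml", "CPU", ov_config)  # placeholder IR path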
optimum/intel/openvino/modeling_seq2seq.py (64 changes: 45 additions & 19 deletions)
@@ -16,14 +16,15 @@
 from pathlib import Path
 from typing import Dict, Optional, Tuple
 
+import numpy as np
 import torch
 import transformers
 from transformers import AutoConfig, AutoModelForSeq2SeqLM
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
 import openvino
-from openvino.runtime import Core
+from openvino.runtime import Core, Tensor
 
 from ..utils.import_utils import is_transformers_version
 from .modeling_base_seq2seq import OVBaseModelForSeq2SeqLM
@@ -126,6 +127,10 @@
 """
 
 
+def _contiguous_helper(tensor: np.ndarray) -> np.ndarray:
+    return tensor if tensor.flags["C_CONTIGUOUS"] else np.ascontiguousarray(tensor)
+
+
 @add_start_docstrings(
     """
     Sequence-to-sequence model with a language modeling head for OpenVINO inference.
@@ -320,17 +325,23 @@ def forward(
 
         self._create_inference_request()
 
+        # Check if inputs are c-like, if not - convert them.
+        input_ids = _contiguous_helper(input_ids.numpy())
+
         inputs = {
-            "input_ids": input_ids,
+            "input_ids": Tensor(input_ids, shared_memory=True),
         }
 
         # Add the attention_mask inputs when needed
         if "attention_mask" in self.input_names:
-            inputs["attention_mask"] = attention_mask
+            attention_mask = _contiguous_helper(attention_mask.numpy())
+            inputs["attention_mask"] = Tensor(attention_mask, shared_memory=True)
 
         # Run inference
-        outputs = self.request.infer(inputs)
-        outputs = {key.get_any_name(): value for key, value in outputs.items()}
-        last_hidden_state = torch.from_numpy(outputs["last_hidden_state"]).to(self.device)
+        self.request.start_async(inputs)
+        self.request.wait()
+
+        last_hidden_state = torch.from_numpy(self.request.get_tensor("last_hidden_state").data).to(self.device)
 
         return BaseModelOutput(last_hidden_state=last_hidden_state)

@@ -375,25 +386,40 @@ def forward(
 
         self._create_inference_request()
 
-        inputs = {
-            "input_ids": input_ids,
-            "encoder_attention_mask": encoder_attention_mask,
-        }
-
-        # Add the encoder_hidden_states inputs when needed
-        if "encoder_hidden_states" in self.input_names:
-            inputs["encoder_hidden_states"] = encoder_hidden_states
+        inputs = {}
 
         if past_key_values is not None:
             # Flatten the past_key_values
-            past_key_values = [past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer]
+            past_key_values = [
+                _contiguous_helper(past_key_value.numpy())
+                for pkv_per_layer in past_key_values
+                for past_key_value in pkv_per_layer
+            ]
             # Add the past_key_values to the decoder inputs
-            for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
-                inputs[input_name] = past_key_value
+            inputs = {
+                input_name: Tensor(past_key_value, shared_memory=True)
+                for input_name, past_key_value in zip(self.key_value_input_names, past_key_values)
+            }
+
+        # Check if inputs are c-like, if not - convert them.
+        input_ids = _contiguous_helper(input_ids.numpy())
+        encoder_attention_mask = _contiguous_helper(encoder_attention_mask.numpy())
+
+        inputs["input_ids"] = Tensor(input_ids, shared_memory=True)
+        inputs["encoder_attention_mask"] = Tensor(encoder_attention_mask, shared_memory=True)
+
+        # Add the encoder_hidden_states inputs when needed
+        if "encoder_hidden_states" in self.input_names:
+            encoder_hidden_states = _contiguous_helper(encoder_hidden_states.numpy())
+            inputs["encoder_hidden_states"] = Tensor(encoder_hidden_states, shared_memory=True)
 
         # Run inference
-        outputs = self.request.infer(inputs)
-        outputs = {key.get_any_name(): value for key, value in outputs.items()}
+        self.request.start_async(inputs)
+        self.request.wait()
+
+        outputs = {
+            key.get_any_name(): value.data for key, value in zip(self.request.model_outputs, self.request.outputs)
+        }
 
         # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
         # self-attention layer and 2 to the cross-attention layer)
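To round off the diff above, a short standalone illustration of why _contiguous_helper is needed before wrapping arrays with shared_memory=True: views such as transposes are not C-contiguous, and the helper copies them into C order only in that case (the example arrays are made up):

import numpy as np


def _contiguous_helper(tensor: np.ndarray) -> np.ndarray:
    # Same guard as in the diff: copy only when the buffer is not laid out in C order.
    return tensor if tensor.flags["C_CONTIGUOUS"] else np.ascontiguousarray(tensor)


a = np.arange(12, dtype=np.float32).reshape(3, 4)
b = a.T  # transposed view, not C-contiguous

print(a.flags["C_CONTIGUOUS"])                      # True, returned unchanged
print(b.flags["C_CONTIGUOUS"])                      # False
print(_contiguous_helper(b).flags["C_CONTIGUOUS"])  # True, copied into C order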
