concurrency without model cloning #573
base: main
@@ -14,9 +14,10 @@
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import openvino
@@ -26,7 +27,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput

from optimum.utils.normalized_config import NormalizedConfigManager
@@ -44,6 +45,25 @@
core = Core()


@dataclass
class OVCausalLMOutputWithPast(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        infer_request (`openvino.runtime.InferRequest`): inference request to be reused in the generation cycles.
        beam_idx (`torch.Tensor`): beam search algorithm context for generation using stateful models.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    infer_request: Optional[openvino.runtime.InferRequest] = None
Review comment: could we rename it to something like …

Reply: the current name is aligned with the OpenVINO API name, so for me infer_request sounds better.

Reply: I think that'd be clearer for users who are not familiar with the OpenVINO ecosystem; also, we don't use …
    past_length: Optional[int] = None


TEXT_GENERATION_EXAMPLE = r"""
    Example of text generation:
    ```python

@@ -118,8 +138,7 @@ def __init__(
        self.key_value_output_names = [key for key in self.output_names if "present" in key]
        self._original_model = self.model.clone()  # keep original model for serialization
        self._pkv_precision = Type.f32
        self.next_beam_idx = None
        self._past_length = 0

        self.update_pkv_precision()
        if self.is_dynamic:
            self.model = self._reshape(self.model, -1, -1)
@@ -197,6 +216,7 @@ def update_pkv_precision(self, force_fp32=False):
            if self.is_dynamic:
                self.model = self._reshape(self.model, -1, -1)
            self.request = None
            self.compiled_model = None

    def _save_pretrained(self, save_directory: Union[str, Path]):
        """
@@ -322,6 +342,7 @@ def normalized_config(self):
    def compile(self):
        if self.request is None:
            super().compile()
            self.compiled_model = self.request
Review comment: it could make sense to also set …
            self.request = self.request.create_infer_request()
Review comment: why not remove this
Suggested change
and use self.request instead of …

    def _make_stateful(self):
@@ -354,12 +375,12 @@ def prepare_inputs(
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_length: Optional[int] = 0,
        **kwargs,
    ) -> Dict:
        batch_size = input_ids.shape[0]
        if self.config.model_type == "bloom":
            batch_size *= self.config.num_attention_heads

        inputs = {}
        if not self.stateful:
            if past_key_values is not None:
@@ -395,17 +416,6 @@
                    else:
                        shape[1] = 0
                    inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())
        else:
            # past_key_values are not used explicitly, instead they are handled inside the model
            if past_key_values is None:
                # This is the first iteration in a sequence, reset all states
                if self.request is not None:
                    self.request.reset_state()
                # Set initial value for the next beam_idx input that will be used at the current iteration
                # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
                self.next_beam_idx = np.arange(batch_size, dtype=int)
                self._past_length = 0
        past_len = self._get_past_length(past_key_values)

        inputs["input_ids"] = np.array(input_ids)
        # Add the attention_mask inputs when needed
@@ -414,7 +424,7 @@
                attention_mask = np.array(attention_mask)
            else:
                attention_mask = np.ones(
                    (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
                    (input_ids.shape[0], input_ids.shape[1] + past_length), dtype=inputs["input_ids"].dtype
                )

        if "attention_mask" in self.input_names:
@@ -432,9 +442,11 @@
            inputs["position_ids"] = position_ids

        if "beam_idx" in self.input_names:
            inputs["beam_idx"] = (
                self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)
            )
            if past_key_values is not None:
                if len(past_key_values[0]) > 0:
                    inputs["beam_idx"] = past_key_values[0]
                    return inputs
            inputs["beam_idx"] = np.arange(batch_size, dtype=int)

        return inputs
@@ -444,33 +456,39 @@ def forward(
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        infer_request: Optional[openvino.runtime.InferRequest] = None,
        past_length: Optional[int] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
    ) -> OVCausalLMOutputWithPast:
        self.compile()
echarlaix marked this conversation as resolved.

        inputs = self.prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
            past_length=past_length,
            **kwargs,
        )

        # Run inference
        self.request.start_async(inputs, share_inputs=True)
        self.request.wait()
        logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
        if infer_request is None:
            self.compile()
            infer_request = self.compiled_model.create_infer_request()

        infer_request.start_async(inputs, share_inputs=True)
        infer_request.wait()
        logits = torch.from_numpy(infer_request.get_tensor("logits").data).to(self.device)
        if self.stateful:
            # Need a marker to differentiate the first generate iteration from the others in
            # the first condition at the function beginning above.
            # It should be something that is not None and it should be True when converted to Boolean.
            past_key_values = ((),)
            self._past_length += input_ids.shape[1]
            past_length += input_ids.shape[1]

        if not self.stateful:
            if self.use_cache:
                # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
                past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
                past_key_values = tuple(infer_request.get_tensor(key).data for key in self.key_value_output_names)
                if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
                    # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                    past_key_values = tuple(
@@ -479,27 +497,52 @@
            else:
                past_key_values = None

        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
        return OVCausalLMOutputWithPast(
            logits=logits, past_key_values=past_key_values, infer_request=infer_request, past_length=past_length
        )

    def _update_model_kwargs_for_generation(
        self,
        outputs: OVCausalLMOutputWithPast,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            standardize_cache_format=standardize_cache_format,
        )
        if "infer_request" in outputs:
            model_kwargs["infer_request"] = outputs["infer_request"]
        if "past_length" in outputs:
            model_kwargs["past_length"] = outputs["past_length"]
        return model_kwargs

    # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        attention_mask = kwargs.get("attention_mask", None)
        use_cache = kwargs.get("use_cache", None)

        infer_request = kwargs.get("infer_request", None)
        past_length = kwargs.get("past_length", 0)

        if past_key_values is not None:
            past_len = self._get_past_length(past_key_values)
            past_length = self._get_past_length(past_key_values, past_length=past_length)
            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_len < input_ids.shape[1]:
                input_ids = input_ids[:, past_len:]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
            # create position_ids on the fly for batch generation
@@ -512,15 +555,17 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "infer_request": infer_request,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "past_length": past_length,
        }

    def _get_past_length(self, past_key_values=None):
    def _get_past_length(self, past_key_values=None, past_length=0):
        if past_key_values is None:
            return 0
        if self.stateful:
            return self._past_length
            return past_length
        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
            return past_key_values[0].shape[-2]
        seq_length_dim = -2
@@ -546,8 +591,10 @@ def _reorder_cache(
        if self.stateful:
            # TODO: Apply it differently based on model type
            # TODO: At least for bloom we need to replicate values for each attention head
            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
            return past_key_values
            # save beam_idx and infer_request to be used as an input in the next iteration
            # here, beam_idx content is passed inside the past_key_values

            return ((beam_idx),)
        else:
            return tuple(
                tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
@@ -675,8 +722,8 @@ def _reorder_cache(
            batch_size = beam_idx.shape[0]
            indices = np.array(range(batch_size * self.config.num_attention_heads))
            indices = indices.reshape([batch_size, self.config.num_attention_heads])
            self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
            return past_key_values
            next_beam_idx = np.take(indices, beam_idx, 0).flatten()
            return ((next_beam_idx),)
        else:
            standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
            reordered_past = tuple(
Review comment: not sure why we need a new attribute here.

Reply: It is needed to create a new infer_request in the context of the generate method for each concurrent thread. So far the model class only had a request attribute, which points to a single static infer_request and cannot be used to allocate new requests. The setup is a bit confusing: the request attribute is set to the compiled_model object in the base class, but it is later overwritten to become an infer_request. Eventually the recommendation would be to switch to using the compiled_model attribute and create infer_requests dynamically; it was proposed to make that switch in a separate PR.
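
To make the intent above concrete, here is a minimal sketch (not taken from this PR) of the usage pattern the change is meant to enable: several threads calling generate() on one shared OVModelForCausalLM, with each call owning its own InferRequest instead of cloning the whole model. The model id, prompts, and worker count are illustrative assumptions.

```python
# Sketch of concurrent generation on a single shared OVModelForCausalLM instance.
from concurrent.futures import ThreadPoolExecutor

from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "gpt2"  # assumption: any causal LM that can be exported to OpenVINO
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)


def generate_one(prompt: str) -> str:
    # Each call builds its own input tensors; with this PR, forward() allocates a fresh
    # InferRequest from the shared compiled model when none is passed, and carries it
    # (together with past_length) through model_kwargs via OVCausalLMOutputWithPast.
    inputs = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=32)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


prompts = ["The weather today is", "OpenVINO is", "Concurrency without cloning means"]
with ThreadPoolExecutor(max_workers=len(prompts)) as pool:
    for text in pool.map(generate_one, prompts):
        print(text)
```

The key point is that all mutable generation state (the InferRequest, past_length, and the beam_idx carried inside past_key_values) travels through each call's model_kwargs, so concurrent threads share only the compiled model.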