Fix export of seq2seq model for optimum v1.7.3 (#253)
echarlaix authored Mar 24, 2023
1 parent ceb7780 commit 3ca307c
Showing 2 changed files with 32 additions and 13 deletions.
43 changes: 31 additions & 12 deletions optimum/intel/openvino/modeling_seq2seq.py
@@ -321,7 +321,7 @@ def forward(
     ) -> BaseModelOutput:
         self._create_inference_request()
 
-        # Check if inputs are c-like, if not - convert them.
+        # Check if inputs are c-like, if not - convert them
         input_ids = _contiguous_helper(np.array(input_ids))
 
         inputs = {
@@ -368,6 +368,15 @@ def __init__(self, model: openvino.runtime.Model, device: str, ov_config: Dict):
         self.device = torch.device("cpu")
         self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
         self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
+        is_legacy = any("past_key_values" in key.get_any_name() for key in self.model.outputs)
+
+        if len(self.key_value_input_names) > 0 and not is_legacy:
+            self.use_past = True
+            self.num_pkv = 2
+        else:
+            self.use_past = False
+            self.num_pkv = 4
+
         self.ov_config = ov_config
         self.request = None
 
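For context, the branch added above can be sketched in isolation. The snippet below is a minimal illustration rather than library code: the input and output names are made up to mimic a cache-aware export (the real names are read from the OpenVINO model's inputs and outputs), and it shows why such an export ends up with use_past = True and 2 present tensors per layer, while legacy or cache-less decoders keep 4.

    # Hypothetical names standing in for model.inputs / model.outputs of an exported decoder.
    input_names = [
        "input_ids",
        "encoder_hidden_states",
        "past_key_values.0.decoder.key",
        "past_key_values.0.decoder.value",
        "past_key_values.0.encoder.key",
        "past_key_values.0.encoder.value",
    ]
    output_names = ["logits", "present.0.decoder.key", "present.0.decoder.value"]

    key_value_input_names = [name for name in input_names if "key_values" in name]
    # Legacy exports re-emitted outputs named "past_key_values.*" instead of "present.*".
    is_legacy = any("past_key_values" in name for name in output_names)

    use_past = len(key_value_input_names) > 0 and not is_legacy
    num_pkv = 2 if use_past else 4
    print(use_past, num_pkv)  # -> True 2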
@@ -385,18 +394,19 @@ def forward(
 
         if past_key_values is not None:
             # Flatten the past_key_values
-            past_key_values = [
+            past_key_values = tuple(
                 _contiguous_helper(np.array(past_key_value))
                 for pkv_per_layer in past_key_values
                 for past_key_value in pkv_per_layer
-            ]
+            )
+
             # Add the past_key_values to the decoder inputs
             inputs = {
                 input_name: Tensor(past_key_value, shared_memory=True)
                 for input_name, past_key_value in zip(self.key_value_input_names, past_key_values)
             }
 
-        # Check if inputs are c-like, if not - convert them.
+        # Check if inputs are c-like, if not - convert them
        input_ids = _contiguous_helper(np.array(input_ids))
        inputs["input_ids"] = Tensor(input_ids, shared_memory=True)
 
@@ -420,22 +430,31 @@ def forward(
             output_name = "logits" if "logits" in output_names else next(iter(output_names))
             outputs[output_name] = value.data
 
-        logits = torch.from_numpy(outputs["logits"]).to(self.device)
-
         # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
         # self-attention layer and 2 to the cross-attention layer)
-        past_key_values = tuple(
+        out_past_key_values = tuple(
             torch.from_numpy(outputs[key]).to(self.device)
             for key in outputs
             if ("key_values" in key or "present" in key)
         )
 
-        # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and
-        # cross-attention per decoder layer
-        num_pkv = 4
-        past_key_values = tuple(past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv))
+        logits = torch.from_numpy(outputs["logits"]).to(self.device)
+
+        # Tuple of tuple of length `n_layers`, with each tuple of length equal to:
+        # * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention)
+        # * 2 for the decoder with cache (k/v of self-attention as cross-attention cache is constant)
+        if self.use_past is False:
+            out_past_key_values = tuple(
+                out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv)
+            )
+        else:
+            # grab the cross attention key/values from the inputs
+            out_past_key_values = tuple(
+                out_past_key_values[i : i + self.num_pkv] + past_key_values[2 * i + 2 : 2 * i + 2 + self.num_pkv]
+                for i in range(0, len(out_past_key_values), self.num_pkv)
+            )
 
-        return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)
+        return Seq2SeqLMOutput(logits=logits, past_key_values=out_past_key_values)
 
     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
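To make the regrouping in the hunk above concrete, here is a minimal, self-contained sketch. regroup_cache is a hypothetical helper written for this example (not part of optimum.intel), and short strings stand in for OpenVINO/torch tensors. It reproduces the two layouts the new code handles: 4 tensors per layer sliced directly from the outputs when the decoder has no cache, versus 2 self-attention outputs per layer recombined with the constant cross-attention key/values taken from the flattened inputs.

    # Hypothetical illustration of the per-layer regrouping performed in OVDecoder.forward.
    def regroup_cache(flat_outputs, flat_inputs, use_past, num_pkv):
        """Regroup a flat tuple of decoder cache tensors into per-layer tuples."""
        if not use_past:
            # Cache-less decoder: 4 outputs per layer (self-attention k/v + cross-attention k/v).
            return tuple(
                flat_outputs[i : i + num_pkv] for i in range(0, len(flat_outputs), num_pkv)
            )
        # Cached decoder: 2 self-attention outputs per layer; the cross-attention k/v are
        # constant, so they are taken back from the inputs (positions 2-3 of each group of 4).
        return tuple(
            flat_outputs[i : i + num_pkv] + flat_inputs[2 * i + 2 : 2 * i + 2 + num_pkv]
            for i in range(0, len(flat_outputs), num_pkv)
        )

    # Two decoder layers, cache-less export: everything comes out of the model.
    no_cache_outputs = ("s0k", "s0v", "c0k", "c0v", "s1k", "s1v", "c1k", "c1v")
    print(regroup_cache(no_cache_outputs, (), use_past=False, num_pkv=4))
    # (('s0k', 's0v', 'c0k', 'c0v'), ('s1k', 's1v', 'c1k', 'c1v'))

    # Cached export: only self-attention k/v come out; cross-attention k/v come from the inputs.
    cached_outputs = ("s0k", "s0v", "s1k", "s1v")
    flat_inputs = ("s0k_in", "s0v_in", "c0k", "c0v", "s1k_in", "s1v_in", "c1k", "c1v")
    print(regroup_cache(cached_outputs, flat_inputs, use_past=True, num_pkv=2))
    # (('s0k', 's0v', 'c0k', 'c0v'), ('s1k', 's1v', 'c1k', 'c1v'))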
2 changes: 1 addition & 1 deletion setup.py
@@ -12,7 +12,7 @@
     assert False, "Error: Could not open '%s' due %s\n" % (filepath, error)
 
 INSTALL_REQUIRE = [
-    "optimum>=1.7.0",
+    "optimum>=1.7.3",
     "transformers>=4.20.0",
     "datasets>=1.4.0",
     "torch",
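As a small usage note (a hypothetical snippet, assuming the packaging package is available), the bumped lower bound can be verified in an existing environment like this:

    # Check that the installed optimum meets the new requirement from setup.py.
    from importlib.metadata import version

    from packaging.version import Version

    assert Version(version("optimum")) >= Version("1.7.3"), "optimum>=1.7.3 is required"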
