update aquila to support v1 and v2
eaidova committed May 17, 2024
1 parent 3db832a commit 17f912a
Showing 4 changed files with 61 additions and 13 deletions.
50 changes: 45 additions & 5 deletions optimum/exporters/openvino/model_configs.py
@@ -334,11 +334,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        past_key_shape = (self.batch_size, self.sequence_length, self.num_attention_heads, self.kv_channels)
        past_value_shape = (self.batch_size, self.sequence_length, self.num_attention_heads, self.kv_channels)
        return [
            (
                self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
                self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
            )
            for _ in range(self.num_layers)
        ]


@@ -658,13 +658,53 @@ class XGLMConfig(TextDecoderWithPositionIdsOnnxConfig):
    )


+class AquilaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task,
+            normalized_config,
+            batch_size,
+            sequence_length,
+            random_batch_size_range,
+            random_sequence_length_range,
+            **kwargs,
+        )
+        self.num_key_value_heads = getattr(
+            normalized_config, "num_key_value_heads", normalized_config.num_attention_heads
+        )
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        shape = (
+            self.batch_size,
+            self.num_key_value_heads,
+            self.sequence_length,
+            self.hidden_size // self.num_attention_heads,
+        )
+        return [
+            (
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]


@register_in_tasks_manager("aquila", *["text-generation", "text-generation-with-past"], library_name="transformers")
class AquilaMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    DEFAULT_ONNX_OPSET = 14

-    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
-    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
-    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, AquilaDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = AquilaDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
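The new generator encodes the grouped-query-attention (GQA) cache layout: each layer gets a (key, value) pair of shape (batch_size, num_key_value_heads, sequence_length, hidden_size // num_attention_heads), with num_key_value_heads falling back to num_attention_heads for configs that predate GQA. A minimal sketch of that shape logic, with toy config values invented purely for illustration:

    # Toy stand-in for a normalized config; the numbers are made up.
    from types import SimpleNamespace

    config = SimpleNamespace(num_attention_heads=8, num_key_value_heads=2, hidden_size=64, num_layers=4)

    # Same fallback as the generator: plain MHA configs define no num_key_value_heads.
    kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
    head_dim = config.hidden_size // config.num_attention_heads

    batch_size, sequence_length = 2, 16
    shape = (batch_size, kv_heads, sequence_length, head_dim)
    print(shape)  # (2, 2, 16, 8), one key and one value tensor of this shape per layer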
19 changes: 12 additions & 7 deletions optimum/exporters/openvino/model_patcher.py
@@ -1099,7 +1099,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    )
    bsz, q_len, _ = hidden_states.size()

-    if self.config.pretraining_tp > 1:
+    if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
@@ -1120,8 +1120,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(
+        bsz, q_len, getattr(self, "num_key_value_heads", self.num_heads), self.head_dim
+    ).transpose(1, 2)
+    value_states = value_states.view(
+        bsz, q_len, getattr(self, "num_key_value_heads", self.num_heads), self.head_dim
+    ).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
@@ -1136,9 +1140,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):

    past_key_value = (key_states, value_states) if use_cache else None

-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
+    if hasattr(self, "num_key_value_groups"):
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states, key_states, value_states, attention_mask, scale=(1 / math.sqrt(self.head_dim))
@@ -1148,7 +1153,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

-    if self.config.pretraining_tp > 1:
+    if hasattr(self.config, "pretraining_tp") and self.config.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
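All three guards serve the same purpose: the patched scaled-dot-product-attention forward now probes pretraining_tp, num_key_value_heads, and num_key_value_groups before using them, so one patch covers attention modules that define the Llama-style GQA attributes as well as those that do not, which is what lets it serve both Aquila generations. For reference, repeat_kv is the usual GQA helper that tiles each key/value head across its query group; a sketch matching the standard transformers implementation (reproduced here for context, not taken from this diff):

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)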
2 changes: 2 additions & 0 deletions tests/openvino/test_modeling.py
@@ -558,6 +558,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
        "cohere",
        "xglm",
        "aquila",
+        "aquila2",
        "xverse",
        "internlm",
    )
@@ -573,6 +574,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
        "orion",
        "phi3",
        "aquila",
+        "aquila2",
        "xverse",
        "internlm",
    )
3 changes: 2 additions & 1 deletion tests/openvino/utils_tests.py
@@ -18,7 +18,8 @@

MODEL_NAMES = {
    "albert": "hf-internal-testing/tiny-random-albert",
-    "aquila": "katuni4ka/tiny-random-aquila",
+    "aquila": "katuni4ka/tiny-random-aquilachat",
+    "aquila2": "katuni4ka/tiny-random-aquila2",
    "audio_spectrogram_transformer": "Ericwang/tiny-random-ast",
    "bge": "BAAI/bge-small-en-v1.5",
    "beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
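With both tiny checkpoints registered, either variant can be exercised end to end. A hedged usage sketch via the standard optimum-intel entry point (the checkpoint is the test fixture added above; Aquila models ship custom modeling code, hence trust_remote_code):

    from optimum.intel import OVModelForCausalLM
    from transformers import AutoTokenizer

    model_id = "katuni4ka/tiny-random-aquila2"  # test fixture from MODEL_NAMES above
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # export=True converts the PyTorch checkpoint to OpenVINO IR on the fly
    model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

    inputs = tokenizer("Hello", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=5)
    print(tokenizer.decode(outputs[0]))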
