From cbbe95aa4992f58ff62ddc7e6c56fe001f8e4bc0 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Thu, 14 Nov 2024 12:42:39 -0800
Subject: [PATCH] Formatting

---
 server/lorax_server/models/causal_lm.py  |  4 +-
 .../models/custom_modeling/mllama.py      | 16 ++----
 server/lorax_server/models/mllama.py      | 57 ++++++++++---------
 server/lorax_server/models/seq2seq_lm.py  |  4 +-
 server/lorax_server/utils/lora.py         |  4 +-
 server/lorax_server/utils/tokens.py       |  4 +-
 6 files changed, 47 insertions(+), 42 deletions(-)

diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py
index 7eeaa9f36..d0968c32a 100644
--- a/server/lorax_server/models/causal_lm.py
+++ b/server/lorax_server/models/causal_lm.py
@@ -679,7 +679,9 @@ def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Option
                 else:
                     seed = None
 
-                generated_text = GeneratedText(output_text, stopping_criteria.current_tokens, reason, seed)
+                generated_text = GeneratedText(
+                    output_text, stopping_criteria.current_tokens, stopping_criteria.current_skipped, reason, seed
+                )
             else:
                 generated_text = None
 
diff --git a/server/lorax_server/models/custom_modeling/mllama.py b/server/lorax_server/models/custom_modeling/mllama.py
index c9ec9dbf7..5c367ac7f 100644
--- a/server/lorax_server/models/custom_modeling/mllama.py
+++ b/server/lorax_server/models/custom_modeling/mllama.py
@@ -225,11 +225,7 @@ def __init__(self, *, prefix, config, weights, layer_id, model_type):
 
         out_size = fc1.linear.weight.shape[-1] * weights.process_group.size()
         self.fc1 = TensorParallelMultiAdapterLinear.load(
-            fc1,
-            layer_id,
-            [f'{model_type}_{FC1}'],
-            sizes=[out_size],
-            process_group=weights.process_group
+            fc1, layer_id, [f"{model_type}_{FC1}"], sizes=[out_size], process_group=weights.process_group
         )
         self.fc2 = TensorParallelAdapterRowLinear.load(
             TensorParallelRowLinear.load(
@@ -239,7 +235,7 @@ def __init__(self, *, prefix, config, weights, layer_id, model_type):
                 bias=True,
             ),
             layer_id,
-            f'{model_type}_{FC2}',
+            f"{model_type}_{FC2}",
             process_group=weights.process_group,
         )
 
@@ -261,7 +257,7 @@ def load_attention(config, prefix, weights, layer_id, model_type, head_dim, n_he
     return TensorParallelMultiAdapterLinear.load(
         base_layer,
         layer_id,
-        [f'{model_type}_{Q_PROJ}', f'{model_type}_{K_PROJ}', f'{model_type}_{V_PROJ}'],
+        [f"{model_type}_{Q_PROJ}", f"{model_type}_{K_PROJ}", f"{model_type}_{V_PROJ}"],
         sizes=[
             head_dim * n_head,
             head_dim * n_head_kv,
@@ -306,7 +302,7 @@ def __init__(self, *, prefix, config, weights, layer_id, model_type):
                 bias=False,
             ),
             layer_id,
-            f'{model_type}_{O_PROJ}',
+            f"{model_type}_{O_PROJ}",
             process_group=weights.process_group,
         )
 
@@ -557,7 +553,7 @@ def __init__(self, *, prefix, config, weights):
             weights=weights,
             is_gated=False,
             num_layers=config.num_hidden_layers,
-            model_type='VISION_TRANSFORMER',
+            model_type="VISION_TRANSFORMER",
         )
         self.global_transformer = MllamaVisionEncoder(
             prefix=f"{prefix}.global_transformer",
@@ -565,7 +561,7 @@ def __init__(self, *, prefix, config, weights):
             weights=weights,
             is_gated=True,
             num_layers=config.num_global_layers,
-            model_type='VISION_GLOBAL_TRANSFORMER',
+            model_type="VISION_GLOBAL_TRANSFORMER",
         )
 
     def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
diff --git a/server/lorax_server/models/mllama.py b/server/lorax_server/models/mllama.py
index 504d8129a..4c514f508 100644
--- a/server/lorax_server/models/mllama.py
+++ b/server/lorax_server/models/mllama.py
@@ -24,6 +24,7 @@
 TEXT_ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD]
 VISION_ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, FC1, FC2]
 
+
 @dataclass
 class MllamaCausalLMBatch(VlmCausalLMBatch):
     image_indices: List[int] = 42
@@ -179,33 +180,34 @@ def from_pb(
 
 
 class MllamaCausalLM(VlmCausalLM):
-
     @property
     def supports_adapter_loading(self) -> bool:
         return True
 
     @property
     def adapter_layers(self) -> List[str]:
-        return [f'TEXT_{layer_type}' for layer_type in TEXT_ADAPTER_LAYERS] \
-            + [f'VISION_GLOBAL_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS] \
-            + [f'VISION_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS]
+        return (
+            [f"TEXT_{layer_type}" for layer_type in TEXT_ADAPTER_LAYERS]
+            + [f"VISION_GLOBAL_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
+            + [f"VISION_TRANSFORMER_{layer_type}" for layer_type in VISION_ADAPTER_LAYERS]
+        )
 
     @property
     def default_traced_adapter_layers(self) -> List[str]:
         return [Q_PROJ, V_PROJ]
 
     def get_num_layers_for_type(self, layer_type: str) -> int:
-        if 'LM_HEAD' in layer_type:
+        if "LM_HEAD" in layer_type:
             return 1
-        if 'TEXT_' in layer_type:
+        if "TEXT_" in layer_type:
             return [
                 layer_id
                 for layer_id, layer in enumerate(self.model.text_model.model.layers)
-                if not isinstance(layer, FlashLlamaCrossLayer) 
+                if not isinstance(layer, FlashLlamaCrossLayer)
             ]
-        if 'VISION_GLOBAL_TRANSFORMER_' in layer_type:
+        if "VISION_GLOBAL_TRANSFORMER_" in layer_type:
             return len(self.model.vision_model.global_transformer.layers)
-        if 'VISION_TRANSFORMER_' in layer_type:
+        if "VISION_TRANSFORMER_" in layer_type:
             return len(self.model.vision_model.transformer.layers)
 
     def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
@@ -215,51 +217,54 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
         layer_weights = {}
 
         prefix = "language_model.model.layers"
         for i, layer in enumerate(self.model.text_model.model.layers):
             if isinstance(layer, FlashLlamaCrossLayer):
                 continue
-            layer_weights[(i, f'TEXT_{Q_PROJ}')] = (
+            layer_weights[(i, f"TEXT_{Q_PROJ}")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, f'TEXT_{K_PROJ}')] = (
+            layer_weights[(i, f"TEXT_{K_PROJ}")] = (
                 f"{prefix}.{i}.self_attn.k_proj",
                 layer.self_attn.query_key_value,
             )
-            layer_weights[(i, f'TEXT_{V_PROJ}')] = (
+            layer_weights[(i, f"TEXT_{V_PROJ}")] = (
                 f"{prefix}.{i}.self_attn.v_proj",
                 layer.self_attn.query_key_value,
            )
-            layer_weights[(i, f'TEXT_{O_PROJ}')] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)
-
-            layer_weights[(i, f'TEXT_{GATE_PROJ}')] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj)
-            layer_weights[(i, f'TEXT_{UP_PROJ}')] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj)
-            layer_weights[(i, f'TEXT_{DOWN_PROJ}')] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)
-        layer_weights[(0, f'TEXT_{LM_HEAD}')] = ("base_model.model.language_model.lm_head", self.model.text_model.lm_head)
+            layer_weights[(i, f"TEXT_{O_PROJ}")] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)
+
+            layer_weights[(i, f"TEXT_{GATE_PROJ}")] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj)
+            layer_weights[(i, f"TEXT_{UP_PROJ}")] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj)
+            layer_weights[(i, f"TEXT_{DOWN_PROJ}")] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)
+        layer_weights[(0, f"TEXT_{LM_HEAD}")] = (
+            "base_model.model.language_model.lm_head",
+            self.model.text_model.lm_head,
+        )
 
         vision_layer_mappings = [
             ("vision_model.global_transformer.layers", self.model.vision_model.global_transformer.layers),
("vision_model.transformer.layers", self.model.vision_model.transformer.layers), ] for prefix, layer_list in vision_layer_mappings: - layer_type_prefix = 'VISION_GLOBAL_TRANSFORMER' if 'global_transformer' in prefix else 'VISION_TRANSFORMER' + layer_type_prefix = "VISION_GLOBAL_TRANSFORMER" if "global_transformer" in prefix else "VISION_TRANSFORMER" for i, layer in enumerate(layer_list): - layer_weights[(i, f'{layer_type_prefix}_{Q_PROJ}')] = ( + layer_weights[(i, f"{layer_type_prefix}_{Q_PROJ}")] = ( f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.qkv_proj, ) - layer_weights[(i, f'{layer_type_prefix}_{K_PROJ}')] = ( + layer_weights[(i, f"{layer_type_prefix}_{K_PROJ}")] = ( f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.qkv_proj, ) - layer_weights[(i, f'{layer_type_prefix}_{V_PROJ}')] = ( + layer_weights[(i, f"{layer_type_prefix}_{V_PROJ}")] = ( f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.qkv_proj, ) - layer_weights[(i, f'{layer_type_prefix}_{O_PROJ}')] = ( + layer_weights[(i, f"{layer_type_prefix}_{O_PROJ}")] = ( f"{prefix}.{i}.self_attn.o_proj", - layer.self_attn.o_proj + layer.self_attn.o_proj, ) - layer_weights[(i, f'{layer_type_prefix}_{FC1}')] = (f"{prefix}.{i}.mlp.fc1", layer.mlp.fc1) - layer_weights[(i, f'{layer_type_prefix}_{FC2}')] = (f"{prefix}.{i}.mlp.fc2", layer.mlp.fc2) + layer_weights[(i, f"{layer_type_prefix}_{FC1}")] = (f"{prefix}.{i}.mlp.fc1", layer.mlp.fc1) + layer_weights[(i, f"{layer_type_prefix}_{FC2}")] = (f"{prefix}.{i}.mlp.fc2", layer.mlp.fc2) return layer_weights diff --git a/server/lorax_server/models/seq2seq_lm.py b/server/lorax_server/models/seq2seq_lm.py index ef2733363..fa4d3a0f3 100644 --- a/server/lorax_server/models/seq2seq_lm.py +++ b/server/lorax_server/models/seq2seq_lm.py @@ -654,7 +654,9 @@ def generate_token(self, batch: Seq2SeqLMBatch) -> Tuple[List[Generation], Optio else: seed = None - generated_text = GeneratedText(output_text, stopping_criteria.current_tokens, reason, seed) + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, stopping_criteria.current_skipped, reason, seed + ) else: generated_text = None diff --git a/server/lorax_server/utils/lora.py b/server/lorax_server/utils/lora.py index 0262f1c34..dbce2bfcc 100644 --- a/server/lorax_server/utils/lora.py +++ b/server/lorax_server/utils/lora.py @@ -8,7 +8,7 @@ UP_PROJ = "up_proj" DOWN_PROJ = "down_proj" -FC1 = 'fc1' -FC2 = 'fc2' +FC1 = "fc1" +FC2 = "fc2" LM_HEAD = "lm_head" diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index 658ef252c..ff31b3ebc 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -187,11 +187,11 @@ def __init__( self.current_output = "" self.current_skipped = 0 self.ignore_eos_token = ignore_eos_token - + def __call__(self, last_token: int, last_output: str, skipped: bool = False) -> Tuple[bool, Optional[str]]: if skipped: self.current_skipped += 1 - + self.current_tokens += 1 if self.current_tokens >= self.max_new_tokens: return True, FinishReason.FINISH_REASON_LENGTH