diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index a8d15cd5332b..37da20005c03 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -32,7 +32,7 @@
 from .runtime.hybrid_engine import DeepSpeedHybridEngine
 from .runtime.pipe.engine import PipelineEngine
 from .inference.engine import InferenceEngine
-from .inference.config import DeepSpeedInferenceConfig
+from .inference.config import DeepSpeedInferenceConfig, QuantizationConfig
 from .runtime.lr_schedules import add_tuning_arguments
 from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
 from .runtime.activation_checkpointing import checkpointing
@@ -359,6 +359,14 @@ def init_inference(model, config=None, **kwargs):
             raise ValueError(f"Conflicting argument '{key}' in 'config':{config_dict[key]} and kwargs:{kwargs[key]}")
     config_dict.update(kwargs)
 
+    # Set the number of weight quantization groups if an optional 'quantize_groups' argument is given
+    if "quantize_groups" in config_dict:
+        if not ("dtype", torch.int8) in config_dict.items():
+            raise ValueError("'dtype' argument expected int8 when 'quantize_groups' argument is provided")
+        quant = QuantizationConfig()
+        quant.weight.q_groups = config_dict.pop("quantize_groups")
+        config_dict["quant"] = quant
+
     ds_inference_config = DeepSpeedInferenceConfig(**config_dict)
 
     engine = InferenceEngine(model, config=ds_inference_config)
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 32c88549c821..5c85b6c4340d 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -194,6 +194,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
     """
     # defining globals as internally defined functions inherit these everywhere
     quantize = (config.dtype == torch.int8)
+    quantize_groups = config.quant.weight.q_groups if quantize else 0
 
     # todo: Refactor later. In future, let's minimize the style used above and use config.** instead
     linear_layer_setting = None
@@ -238,7 +239,7 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False,
                 _container.convert_to_required_dtype()
 
             # 5. Set the quantization config
-            quantizer = GroupQuantizer(q_int8=quantize)
+            quantizer = GroupQuantizer(q_int8=quantize, num_groups=quantize_groups)
             _container.set_quantization_config(quantizer)
 
             # 6. create a DS Inference config object
@@ -401,7 +402,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
         if not config.replace_with_kernel_inject:
             replaced_module = set_lm_head(replaced_module)
 
-    quantizer = GroupQuantizer(q_int8=quantize)
+    quantizer = GroupQuantizer(q_int8=quantize, num_groups=quantize_groups)
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    if checkpoint_dict is not None and config.replace_with_kernel_inject:
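
For reference, this is how the new `quantize_groups` shortcut is exercised from user code. A minimal usage sketch, assuming `transformers` is installed; the model choice and group count are illustrative only:

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM

# "gpt2" is an illustrative choice; any supported torch.nn.Module works here.
model = AutoModelForCausalLM.from_pretrained("gpt2")

# 'quantize_groups' is only accepted together with dtype=torch.int8; otherwise
# init_inference raises the ValueError introduced in this patch.
engine = deepspeed.init_inference(
    model,
    dtype=torch.int8,        # required when passing quantize_groups
    quantize_groups=64,      # popped from kwargs, stored as quant.weight.q_groups
    replace_with_kernel_inject=True,
)
```

The kwarg is sugar for populating `quant.weight.q_groups` on `DeepSpeedInferenceConfig`; passing the nested `quant` config explicitly remains equivalent.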
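For readers unfamiliar with what `num_groups` controls: group-wise quantization splits a weight tensor into that many chunks and computes one scale per chunk, so more groups means finer-grained scales and lower quantization error. The sketch below is a standalone illustration of the idea, not DeepSpeed's `GroupQuantizer` (which has its own kernel path and tensor-layout handling); all names in it are hypothetical:

```python
import torch

def group_quantize(weight: torch.Tensor, num_groups: int, num_bits: int = 8):
    """Symmetric per-group quantization: one scale per group of elements."""
    q_range = 2 ** (num_bits - 1) - 1                   # 127 for int8
    groups = weight.reshape(num_groups, -1)             # split into equal chunks
    scales = groups.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8) / q_range
    q = (groups / scales).round().clamp(-q_range, q_range).to(torch.int8)
    return q.reshape(weight.shape), scales

w = torch.randn(256, 256)
for n in (1, 64):
    q, s = group_quantize(w, num_groups=n)
    dequant = (q.reshape(n, -1).float() * s).reshape(w.shape)
    print(n, (dequant - w).abs().mean().item())  # error shrinks as groups grow
```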