Extend DeepSpeed inference initialization API with a 'quantize_groups' argument #3519

Open · wants to merge 7 commits into master
Changes from 6 commits
10 changes: 9 additions & 1 deletion deepspeed/__init__.py
@@ -27,7 +27,7 @@
 from .runtime.hybrid_engine import DeepSpeedHybridEngine
 from .runtime.pipe.engine import PipelineEngine
 from .inference.engine import InferenceEngine
-from .inference.config import DeepSpeedInferenceConfig
+from .inference.config import DeepSpeedInferenceConfig, QuantizationConfig
 from .runtime.lr_schedules import add_tuning_arguments
 from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
 from .runtime.activation_checkpointing import checkpointing
@@ -337,6 +337,14 @@ def init_inference(model, config=None, **kwargs):
raise ValueError(f"Conflicting argument '{key}' in 'config':{config_dict[key]} and kwargs:{kwargs[key]}")
config_dict.update(kwargs)

# Set the number of weight quantization groups if an optional 'quantize_groups' argument is given
if "quantize_groups" in config_dict:
if not ("dtype", torch.int8) in config_dict.items():
raise ValueError("'dtype' argument expected int8 when 'quantize_groups' argument is provided")
quant = QuantizationConfig()
quant.weight.q_groups = config_dict.pop("quantize_groups")
config_dict["quant"] = quant

Comment on lines +362 to +369
Contributor
I believe you are adding quantize_groups as a shortcut for the admittedly convoluted current quantize config settings? For example, with this you could just pass quantize_groups=2 rather than quant={"weight":{"q_groups":2}}.

But perhaps we should look into how we can simplify the quantize settings or at the very least add this logic to the DeepSpeedInferenceConfig class as a pydantic validator (so that the config logic is consolidated there).
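
A minimal sketch of what that pydantic-validator consolidation could look like (hypothetical, not part of this PR; assumes DeepSpeed's pydantic-v1-style `DeepSpeedConfigModel` base class, and that `quant` is passed as a plain dict, if at all):

```python
# Hypothetical sketch, not part of this PR: moving the shortcut into the
# config class as a pydantic v1-style pre-validator, so init_inference
# needs no special case. Assumes 'quant' arrives as a plain dict or is
# absent, not as an already-constructed QuantizationConfig object.
import torch
from pydantic import root_validator

from deepspeed.runtime.config_utils import DeepSpeedConfigModel


class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
    @root_validator(pre=True)
    def _expand_quantize_groups(cls, values):  # hypothetical validator name
        if "quantize_groups" in values:
            if values.get("dtype") != torch.int8:
                raise ValueError("'dtype' must be torch.int8 when 'quantize_groups' is given")
            # Fold the flat shortcut into the nested quant config.
            quant = values.setdefault("quant", {})
            quant.setdefault("weight", {})["q_groups"] = values.pop("quantize_groups")
        return values
```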

Contributor Author

Yes, I suppose one could use a config file with quant={"weight":{"q_groups":2}} instead. What I am suggesting here is a simple way to control this setting from the command line. I do agree that there might be better ways of achieving that than special-casing this argument in init_inference.

     ds_inference_config = DeepSpeedInferenceConfig(**config_dict)
 
     engine = InferenceEngine(model, config=ds_inference_config)
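
For illustration, the two equivalent call styles with this change applied (a sketch; `model` stands for any already-constructed torch module):

```python
import torch
import deepspeed

# With this PR: flat shortcut, easy to wire to a command-line flag
engine = deepspeed.init_inference(model, dtype=torch.int8, quantize_groups=2)

# Equivalent nested form that works without the shortcut
engine = deepspeed.init_inference(model,
                                  dtype=torch.int8,
                                  quant={"weight": {"q_groups": 2}})
```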
5 changes: 3 additions & 2 deletions deepspeed/module_inject/replace_module.py
@@ -193,6 +193,7 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m
"""
# defining globals as internally defined functions inherit these everywhere
quantize = (config.dtype == torch.int8)
quantize_groups = config.quant.weight.q_groups if quantize else 0
# todo: Refactor later. In future, let's minimize the style used above and use config.** instead

linear_layer_setting = None
@@ -237,7 +238,7 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False,
         _container.convert_to_required_dtype()
 
         # 5. Set the quantization config
-        quantizer = GroupQuantizer(q_int8=quantize)
+        quantizer = GroupQuantizer(q_int8=quantize, num_groups=quantize_groups)
         _container.set_quantization_config(quantizer)
 
         # 6. create a DS Inference config object
@@ -341,7 +342,7 @@ def set_lm_head(module):
                                     replace_fn=replace_fn,
                                     _replace_policy=config.injection_policy_tuple)
 
-    quantizer = GroupQuantizer(q_int8=quantize)
+    quantizer = GroupQuantizer(q_int8=quantize, num_groups=quantize_groups)
     world_size = dist.get_world_size() if dist.is_initialized() else 1
     rank = dist.get_rank() if dist.is_initialized() else 0
     if checkpoint_dict is not None and config.replace_with_kernel_inject:
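
For context, `num_groups` controls how finely the weights are bucketed for int8 scaling. A standalone sketch of the idea behind group-wise quantization (not DeepSpeed's actual GroupQuantizer; assumes the tensor size divides evenly by the group count):

```python
import torch

def group_quantize(weight: torch.Tensor, num_groups: int):
    # Each group gets its own symmetric int8 scale, so a single outlier
    # only degrades precision within its group rather than the whole tensor.
    flat = weight.reshape(num_groups, -1)
    scale = flat.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8) / 127.0
    q = torch.clamp((flat / scale).round(), -128, 127).to(torch.int8)
    return q.reshape(weight.shape), scale
```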