diff --git a/src/sparseml/transformers/sparsification/sparse_model.py b/src/sparseml/transformers/sparsification/sparse_model.py
index c5e17764874..31feeeef88c 100644
--- a/src/sparseml/transformers/sparsification/sparse_model.py
+++ b/src/sparseml/transformers/sparsification/sparse_model.py
@@ -101,16 +101,29 @@ def skip(*args, **kwargs):
         )
 
         # instantiate compressor from model config
-        compressor = ModelCompressor.from_pretrained(pretrained_model_name_or_path)
+        compressor = ModelCompressor.from_pretrained(
+            pretrained_model_name_or_path, **kwargs
+        )
 
         # temporarily set the log level to error, to ignore printing out long missing
         # and unexpected key error messages (these are EXPECTED for quantized models)
         logger = logging.getLogger("transformers.modeling_utils")
         restore_log_level = logger.getEffectiveLevel()
         logger.setLevel(level=logging.ERROR)
+
+        if kwargs.get("trust_remote_code"):
+            # By artificially aliasing the class name
+            # SparseAutoModelForCausalLM to
+            # AutoModelForCausalLM, we can "trick" the
+            # `from_pretrained` method into properly
+            # resolving the logic when
+            # (has_remote_code and trust_remote_code) == True
+            cls.__name__ = AutoModelForCausalLM.__name__
+
         model = super(AutoModelForCausalLM, cls).from_pretrained(
             pretrained_model_name_or_path, *model_args, **kwargs
         )
+
         if model.dtype != model.config.torch_dtype:
             _LOGGER.warning(
                 f"The dtype of the loaded model: {model.dtype} is different "
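
For context on the aliasing trick: `transformers`' auto-class `from_pretrained` decides whether a checkpoint ships remote code by checking `cls.__name__` against the keys of the config's `auto_map`, so a subclass named `SparseAutoModelForCausalLM` would never match the `"AutoModelForCausalLM"` entry without the rename. A minimal usage sketch follows; the Hub repo name is hypothetical, and the exact remote-code resolution details may vary across `transformers` versions.

```python
from sparseml.transformers import SparseAutoModelForCausalLM

# "some-org/model-with-remote-code" is a hypothetical Hub repo whose config
# defines auto_map = {"AutoModelForCausalLM": "..."}. With the __name__ alias
# applied in the diff above, the (has_remote_code and trust_remote_code)
# branch is taken and the remote model class is loaded.
model = SparseAutoModelForCausalLM.from_pretrained(
    "some-org/model-with-remote-code",
    trust_remote_code=True,
)
```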