From 1a2e42f9ed79d71ef2507c2f8e7160a7cdcf3175 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Thu, 1 Sep 2022 00:08:47 +0800
Subject: [PATCH] Align with transformers the download and cache of INC config and quantized models (#27)

Signed-off-by: Wang, Yi A
---
 .../intel/neural_compressor/configuration.py | 59 +++++++++++--------
 .../intel/neural_compressor/quantization.py  | 52 ++++++++++------
 2 files changed, 68 insertions(+), 43 deletions(-)

diff --git a/optimum/intel/neural_compressor/configuration.py b/optimum/intel/neural_compressor/configuration.py
index c55b6089e0..8978abc94c 100644
--- a/optimum/intel/neural_compressor/configuration.py
+++ b/optimum/intel/neural_compressor/configuration.py
@@ -15,11 +15,13 @@
 import logging
 import os
 from functools import reduce
+from pathlib import Path
 from typing import Any, Optional, Union
 
-from transformers.file_utils import cached_path, hf_bucket_url
+from transformers.utils import TRANSFORMERS_CACHE, is_offline_mode
 
 import yaml
+from huggingface_hub import hf_hub_download
 from neural_compressor.conf.config import Conf, Distillation_Conf, Pruning_Conf, Quantization_Conf
 
 from optimum.intel.neural_compressor.utils import CONFIG_NAME
@@ -97,37 +99,46 @@ def from_pretrained(cls, config_name_or_path: str, config_file_name: Optional[st
         revision = kwargs.get("revision", None)
 
         config_file_name = config_file_name if config_file_name is not None else CONFIG_NAME
+
         if os.path.isdir(config_name_or_path):
             config_file = os.path.join(config_name_or_path, config_file_name)
         elif os.path.isfile(config_name_or_path):
             config_file = config_name_or_path
         else:
-            config_file = hf_bucket_url(config_name_or_path, filename=config_file_name, revision=revision)
-
-        try:
-            resolved_config_file = cached_path(
-                config_file,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-            )
-        except EnvironmentError as err:
-            logger.error(err)
-            msg = (
-                f"Can't load config for '{config_name_or_path}'. Make sure that:\n\n"
-                f"-'{config_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
-                f"-or '{config_name_or_path}' is a correct path to a directory containing a {config_file_name} file\n\n"
-            )
-
-            if revision is not None:
-                msg += (
-                    f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that "
-                    f"exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
-                )
-
-            raise EnvironmentError(msg)
-
-        config = cls(resolved_config_file)
+            local_files_only = False
+            if is_offline_mode():
+                logger.info("Offline mode: forcing local_files_only=True")
+                local_files_only = True
+            if cache_dir is None:
+                cache_dir = TRANSFORMERS_CACHE
+            if isinstance(cache_dir, Path):
+                cache_dir = str(cache_dir)
+            try:
+                config_file = hf_hub_download(
+                    repo_id=config_name_or_path,
+                    filename=config_file_name,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                )
+            except EnvironmentError as err:
+                logger.error(err)
+                msg = (
+                    f"Can't load config for '{config_name_or_path}'. Make sure that:\n\n"
+                    f"-'{config_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                    f"-or '{config_name_or_path}' is a correct path to a directory containing a {config_file_name} file\n\n"
+                )
+
+                if revision is not None:
+                    msg += (
+                        f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that "
+                        f"exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
+                    )
+
+                raise EnvironmentError(msg)
+
+        config = cls(config_file)
 
         return config

diff --git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py
index 976bab1656..34a7cd6d08 100644
--- a/optimum/intel/neural_compressor/quantization.py
+++ b/optimum/intel/neural_compressor/quantization.py
@@ -16,6 +16,7 @@
 import logging
 import os
 from enum import Enum
+from pathlib import Path
 from typing import Callable, ClassVar, Dict, Optional, Union
 
 import torch
@@ -33,11 +34,12 @@
     AutoModelForTokenClassification,
     XLNetLMHeadModel,
 )
-from transformers.file_utils import cached_path, hf_bucket_url
 from transformers.models.auto.auto_factory import _get_model_class
+from transformers.utils import TRANSFORMERS_CACHE, is_offline_mode
 from transformers.utils.versions import require_version
 
 import neural_compressor
+from huggingface_hub import hf_hub_download
 from neural_compressor.adaptor.pytorch import PyTorch_FXAdaptor, _cfg_to_qconfig, _propagate_qconfig
 from neural_compressor.adaptor.torch_utils.util import get_embedding_contiguous
 from neural_compressor.conf.config import Quantization_Conf
@@ -276,26 +278,38 @@ def from_pretrained(
         elif os.path.isfile(model_name_or_path):
             state_dict_path = model_name_or_path
         else:
-            state_dict_path = hf_bucket_url(model_name_or_path, filename=q_model_name, revision=revision)
-
-        try:
-            state_dict_path = cached_path(state_dict_path, **download_kwargs)
-        except EnvironmentError as err:
-            logger.error(err)
-            msg = (
-                f"Can't load config for '{model_name_or_path}'. Make sure that:\n\n - '{model_name_or_path}' is a "
-                f"correct model identifier listed on 'https://huggingface.co/models'\n\n - or "
-                f"'{model_name_or_path}' is a correct path to a directory containing a {q_model_name} file\n\n"
-            )
-
-            if revision is not None:
-                msg += (
-                    f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) "
-                    f"thatexists for this model name as listed on its model page on "
-                    f"'https://huggingface.co/models'\n\n"
-                )
-
-            raise EnvironmentError(msg)
+            local_files_only = False
+            if is_offline_mode():
+                logger.info("Offline mode: forcing local_files_only=True")
+                local_files_only = True
+            cache_dir = download_kwargs.get("cache_dir", None)
+            if cache_dir is None:
+                cache_dir = TRANSFORMERS_CACHE
+            if isinstance(cache_dir, Path):
+                cache_dir = str(cache_dir)
+            try:
+                state_dict_path = hf_hub_download(
+                    repo_id=model_name_or_path,
+                    filename=q_model_name,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    local_files_only=local_files_only,
+                )
+            except EnvironmentError as err:
+                logger.error(err)
+                msg = (
+                    f"Can't load config for '{model_name_or_path}'. Make sure that:\n\n"
+                    f"-'{model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                    f"-or '{model_name_or_path}' is a correct path to a directory containing a {q_model_name} file\n\n"
+                )
+
+                if revision is not None:
+                    msg += (
+                        f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that "
+                        f"exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
+                    )
+
+                raise EnvironmentError(msg)
 
         state_dict = torch.load(state_dict_path)
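
Note for reviewers (not part of the applied diff): both hunks replace the removed `hf_bucket_url` + `cached_path` pair with a single `hf_hub_download` call that honors transformers' offline mode and its default cache. Below is a minimal, self-contained sketch of that resolution flow; `resolve_file` and the example repo id are illustrative names only, not part of this patch, and it assumes a `transformers` version that still exports `TRANSFORMERS_CACHE` (as the patched code does):

    # Sketch of the download/cache resolution the patch introduces.
    import os
    from pathlib import Path
    from typing import Optional, Union

    from huggingface_hub import hf_hub_download
    from transformers.utils import TRANSFORMERS_CACHE, is_offline_mode


    def resolve_file(
        name_or_path: str,
        filename: str,
        cache_dir: Optional[Union[str, Path]] = None,
        revision: Optional[str] = None,
    ) -> str:
        # A local directory or file short-circuits any download.
        if os.path.isdir(name_or_path):
            return os.path.join(name_or_path, filename)
        if os.path.isfile(name_or_path):
            return name_or_path
        # Otherwise treat the argument as a Hub repo id; offline mode
        # forces cache-only lookups, mirroring transformers' behavior.
        local_files_only = is_offline_mode()
        if cache_dir is None:
            cache_dir = TRANSFORMERS_CACHE
        if isinstance(cache_dir, Path):
            cache_dir = str(cache_dir)  # the patch normalizes Path -> str
        return hf_hub_download(
            repo_id=name_or_path,
            filename=filename,
            revision=revision,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )

For example, `resolve_file("distilbert-base-uncased", "config.json")` (placeholder repo id) downloads into the shared transformers cache and returns the cached path, while a local directory or file path is returned directly, so local and Hub-hosted artifacts flow through one code path.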