Align with transformers the download and cache of INC config and quantized models (#27)

Signed-off-by: Wang, Yi A <[email protected]>

sywangyi authored Aug 31, 2022
1 parent 03c77ae commit 1a2e42f
Showing 2 changed files with 68 additions and 43 deletions.
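For reference, both files converge on the same resolution logic: use a local directory or file path when one is given, and otherwise fetch the file from the Hugging Face Hub with hf_hub_download, honoring transformers' offline mode and default cache. Since hf_hub_download returns the path of the cached file, the old hf_bucket_url + cached_path two-step is no longer needed. A condensed sketch of that shared pattern (the function name and arguments below are illustrative, not part of the commit):

# Condensed sketch of the resolution pattern this commit introduces in
# both files; resolve_file and its signature are illustrative only.
import os
from pathlib import Path

from huggingface_hub import hf_hub_download
from transformers.utils import TRANSFORMERS_CACHE, is_offline_mode


def resolve_file(name_or_path, file_name, cache_dir=None, revision=None):
    if os.path.isdir(name_or_path):
        # Local directory: the file is expected to already be inside it.
        return os.path.join(name_or_path, file_name)
    if os.path.isfile(name_or_path):
        # Direct path to the file itself.
        return name_or_path
    # Hub identifier: mirror transformers' offline and cache handling.
    local_files_only = is_offline_mode()
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    # Returns the local path of the downloaded (or already cached) file.
    return hf_hub_download(
        repo_id=name_or_path,
        filename=file_name,
        revision=revision,
        cache_dir=cache_dir,
        local_files_only=local_files_only,
    )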
59 changes: 35 additions & 24 deletions optimum/intel/neural_compressor/configuration.py
@@ -15,11 +15,13 @@
 import logging
 import os
 from functools import reduce
+from pathlib import Path
 from typing import Any, Optional, Union
 
-from transformers.file_utils import cached_path, hf_bucket_url
+from transformers.utils import TRANSFORMERS_CACHE, is_offline_mode
 
 import yaml
+from huggingface_hub import hf_hub_download
 from neural_compressor.conf.config import Conf, Distillation_Conf, Pruning_Conf, Quantization_Conf
 from optimum.intel.neural_compressor.utils import CONFIG_NAME
 
@@ -97,37 +99,46 @@ def from_pretrained(cls, config_name_or_path: str, config_file_name: Optional[st
         revision = kwargs.get("revision", None)
 
         config_file_name = config_file_name if config_file_name is not None else CONFIG_NAME
 
         if os.path.isdir(config_name_or_path):
             config_file = os.path.join(config_name_or_path, config_file_name)
         elif os.path.isfile(config_name_or_path):
             config_file = config_name_or_path
         else:
-            config_file = hf_bucket_url(config_name_or_path, filename=config_file_name, revision=revision)
-
-        try:
-            resolved_config_file = cached_path(
-                config_file,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                resume_download=resume_download,
-            )
-        except EnvironmentError as err:
-            logger.error(err)
-            msg = (
-                f"Can't load config for '{config_name_or_path}'. Make sure that:\n\n"
-                f"-'{config_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
-                f"-or '{config_name_or_path}' is a correct path to a directory containing a {config_file_name} file\n\n"
-            )
-
-            if revision is not None:
-                msg += (
-                    f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that "
-                    f"exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
-                )
-
-            raise EnvironmentError(msg)
-
-        config = cls(resolved_config_file)
+            local_files_only = False
+            if is_offline_mode():
+                logger.info("Offline mode: forcing local_files_only=True")
+                local_files_only = True
+            if cache_dir is None:
+                cache_dir = TRANSFORMERS_CACHE
+            if isinstance(cache_dir, Path):
+                cache_dir = str(cache_dir)
+            try:
+                config_file = hf_hub_download(
+                    repo_id=config_name_or_path,
+                    filename=config_file_name,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                )
+            except EnvironmentError as err:
+                logger.error(err)
+                msg = (
+                    f"Can't load config for '{config_name_or_path}'. Make sure that:\n\n"
+                    f"-'{config_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                    f"-or '{config_name_or_path}' is a correct path to a directory containing a {config_file_name} file\n\n"
+                )
+
+                if revision is not None:
+                    msg += (
+                        f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that "
+                        f"exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
+                    )
+
+                raise EnvironmentError(msg)
+        config = cls(config_file)
 
         return config

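One practical consequence of the change above: config resolution now respects transformers' offline mode. A minimal check, assuming only that TRANSFORMERS_OFFLINE is the documented environment switch read by is_offline_mode:

import os

# Must be set before transformers is first imported; is_offline_mode()
# reads the variable at module import time.
os.environ["TRANSFORMERS_OFFLINE"] = "1"

from transformers.utils import is_offline_mode

# from_pretrained will now log "Offline mode: forcing local_files_only=True"
# and resolve files from the local cache only.
assert is_offline_mode()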
52 changes: 33 additions & 19 deletions optimum/intel/neural_compressor/quantization.py
@@ -16,6 +16,7 @@
 import logging
 import os
 from enum import Enum
+from pathlib import Path
 from typing import Callable, ClassVar, Dict, Optional, Union
 
 import torch
@@ -33,11 +34,12 @@
     AutoModelForTokenClassification,
     XLNetLMHeadModel,
 )
-from transformers.file_utils import cached_path, hf_bucket_url
 from transformers.models.auto.auto_factory import _get_model_class
+from transformers.utils import TRANSFORMERS_CACHE, is_offline_mode
 from transformers.utils.versions import require_version
 
 import neural_compressor
+from huggingface_hub import hf_hub_download
 from neural_compressor.adaptor.pytorch import PyTorch_FXAdaptor, _cfg_to_qconfig, _propagate_qconfig
 from neural_compressor.adaptor.torch_utils.util import get_embedding_contiguous
 from neural_compressor.conf.config import Quantization_Conf
@@ -276,26 +278,38 @@ def from_pretrained(
         elif os.path.isfile(model_name_or_path):
             state_dict_path = model_name_or_path
         else:
-            state_dict_path = hf_bucket_url(model_name_or_path, filename=q_model_name, revision=revision)
-
-        try:
-            state_dict_path = cached_path(state_dict_path, **download_kwargs)
-        except EnvironmentError as err:
-            logger.error(err)
-            msg = (
-                f"Can't load config for '{model_name_or_path}'. Make sure that:\n\n - '{model_name_or_path}' is a "
-                f"correct model identifier listed on 'https://huggingface.co/models'\n\n - or "
-                f"'{model_name_or_path}' is a correct path to a directory containing a {q_model_name} file\n\n"
-            )
-
-            if revision is not None:
-                msg += (
-                    f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) "
-                    f"thatexists for this model name as listed on its model page on "
-                    f"'https://huggingface.co/models'\n\n"
-                )
-
-            raise EnvironmentError(msg)
+            local_files_only = False
+            if is_offline_mode():
+                logger.info("Offline mode: forcing local_files_only=True")
+                local_files_only = True
+            cache_dir = download_kwargs.get("cache_dir", None)
+            if cache_dir is None:
+                cache_dir = TRANSFORMERS_CACHE
+            if isinstance(cache_dir, Path):
+                cache_dir = str(cache_dir)
+            try:
+                state_dict_path = hf_hub_download(
+                    repo_id=model_name_or_path,
+                    filename=q_model_name,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    local_files_only=local_files_only,
+                )
+            except EnvironmentError as err:
+                logger.error(err)
+                msg = (
+                    f"Can't load config for '{model_name_or_path}'. Make sure that:\n\n"
+                    f"-'{model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
+                    f"-or '{model_name_or_path}' is a correct path to a directory containing a {q_model_name} file\n\n"
+                )
+
+                if revision is not None:
+                    msg += (
+                        f"- or '{revision}' is a valid git identifier (branch name, a tag name, or a commit id) that "
+                        f"exists for this model name as listed on its model page on 'https://huggingface.co/models'\n\n"
+                    )
+
+                raise EnvironmentError(msg)
 
         state_dict = torch.load(state_dict_path)
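The quantized-model path follows the same shape: resolve the weights file through the Hub cache, then hand the local path to torch.load. A hedged sketch of the equivalent standalone calls (the repo id and filename are hypothetical placeholders, not names from this commit):

# Hypothetical repo id and weights filename; in from_pretrained above the
# real filename comes from q_model_name.
import torch
from huggingface_hub import hf_hub_download

state_dict_path = hf_hub_download(
    repo_id="some-org/some-quantized-model",
    filename="pytorch_model.bin",
)
state_dict = torch.load(state_dict_path)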
