diff --git a/outpostkit/repository/_loaders/__init__.py b/outpostkit/repository/_loaders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/outpostkit/repository/_loaders/transformers/__init__.py b/outpostkit/repository/_loaders/transformers/__init__.py new file mode 100644 index 0000000..bc5a89c --- /dev/null +++ b/outpostkit/repository/_loaders/transformers/__init__.py @@ -0,0 +1,484 @@ +import copy +import json +import os +from typing import Optional + +from outpostkit._utils.import_utils import is_peft_available, is_transformers_available +from outpostkit.logger import init_outpost_logger +from outpostkit.repository._loaders.transformers.constants import ( + FLAX_WEIGHTS_NAME, + PT_WEIGHTS_INDEX_NAME, + PT_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + TF2_WEIGHTS_NAME, + TF_WEIGHTS_NAME, +) +from outpostkit.repository._loaders.transformers.peft import find_adapter_config_file + +logger = init_outpost_logger(__name__) + +if is_transformers_available: + from transformers import AutoConfig, PretrainedConfig + + +# MODEL_CARD_NAME = "modelcard.json" +def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: + if variant is not None: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + + return weights_name + + +# ref: https://github.com/huggingface/transformers/blob/a5e5c92aea1e99cb84d7342bd63826ca6cd884c4/src/transformers/models/auto/auto_factory.py#L445 +def setup_model_for_transformers( + full_name_or_dir: str, store_dir: str, *model_args, **kwargs +): + use_safetensors: bool = kwargs.pop("use_safetensors", None) + config = kwargs.pop("config", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + kwargs["_from_auto"] = True + + hub_kwargs_names = [ + # "cache_dir", + # "force_download", + # "local_files_only", + # "proxies", + # "resume_download", + "revision", + "subfolder", + # "use_auth_token", + "token", + ] + + hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} + code_revision = kwargs.pop("code_revision", None) + adapter_kwargs = kwargs.pop("adapter_kwargs", None) + token = hub_kwargs.pop("token", None) + revision = str(kwargs.get("revision")) + if token is not None: + hub_kwargs["token"] = token + + # if resolved is None: + # if not isinstance(config, PretrainedConfig): + # # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible + # resolved_config_file = get_file( + # full_name_or_dir=full_name_or_dir, + # repo_type="model", + # file_path=CONFIG_NAME, + # **hub_kwargs, + # ) + # else: + # commit_hash = getattr(config, "_commit_hash", None) + + if is_peft_available(): + if adapter_kwargs is None: + adapter_kwargs = {} + if token is not None: + adapter_kwargs["token"] = token + + maybe_adapter_path = find_adapter_config_file( + full_name_or_dir, + ref=revision, + **adapter_kwargs, + ) + + if maybe_adapter_path is not None: + with open(maybe_adapter_path, encoding="utf-8") as f: + adapter_config = json.load(f) + + adapter_kwargs["_adapter_model_path"] = full_name_or_dir + pretrained_model_name_or_path = adapter_config[ + "base_model_name_or_path" + ] + + if not isinstance(config, PretrainedConfig): + kwargs_orig = copy.deepcopy(kwargs) + # ensure not to pollute the config object with torch_dtype="auto" - since it's + # meaningless in the context of the config object - torch.dtype values are acceptable + if kwargs.get("torch_dtype", None) == "auto": + _ = kwargs.pop("torch_dtype") + # to not overwrite the quantization_config if config has a quantization_config + if kwargs.get("quantization_config", None) is not None: + _ = kwargs.pop("quantization_config") + + config, kwargs = AutoConfig.from_pretrained( + pretrained_model_name_or_path, + return_unused_kwargs=True, + trust_remote_code=trust_remote_code, + code_revision=code_revision, + **hub_kwargs, + **kwargs, + ) + + # if torch_dtype=auto was passed here, ensure to pass it on + if kwargs_orig.get("torch_dtype", None) == "auto": + kwargs["torch_dtype"] = "auto" + if kwargs_orig.get("quantization_config", None) is not None: + kwargs["quantization_config"] = kwargs_orig["quantization_config"] + + has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map + has_local_code = type(config) in cls._model_mapping.keys() + + from_tf = kwargs.pop("from_tf", False) + from_flax = kwargs.pop("from_flax", False) + variant = kwargs.pop("variant", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + variant = kwargs.pop("variant", None) + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + if from_tf and os.path.isfile( + os.path.join( + pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index" + ) + ): + # Load from a TF 1.0 checkpoint in priority if from_tf + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index" + ) + elif from_tf and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) + ): + # Load from a TF 2.0 checkpoint in priority if from_tf + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME + ) + elif from_flax and os.path.isfile( + os.path.join( + pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME + ) + ): + # Load from a Flax checkpoint in priority if from_flax + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME + ) + elif use_safetensors is not False and os.path.isfile( + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_NAME, variant), + ) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_NAME, variant), + ) + elif use_safetensors is not False and os.path.isfile( + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + ) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + ) + is_sharded = True + elif os.path.isfile( + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(PT_WEIGHTS_NAME, variant), + ) + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(PT_WEIGHTS_NAME, variant), + ) + elif os.path.isfile( + os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(PT_WEIGHTS_INDEX_NAME, variant), + ) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, + subfolder, + _add_variant(PT_WEIGHTS_INDEX_NAME, variant), + ) + is_sharded = True + # At this stage we don't have a weight file so we will raise an error. + elif os.path.isfile( + os.path.join( + pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index" + ) + ) or os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) + ): + raise OSError( + f"Error no file named {_add_variant(PT_WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use" + " `from_tf=True` to load this model from those weights." + ) + elif os.path.isfile( + os.path.join( + pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME + ) + ): + raise OSError( + f"Error no file named {_add_variant(PT_WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`" + " to load this model from those weights." + ) + elif use_safetensors: + raise OSError( + f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path}." + ) + else: + raise OSError( + f"Error no file named {_add_variant(PT_WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}," + f" {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory" + f" {pretrained_model_name_or_path}." + ) + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + archive_file = pretrained_model_name_or_path + is_local = True + elif os.path.isfile( + os.path.join(subfolder, pretrained_model_name_or_path + ".index") + ): + if not from_tf: + raise ValueError( + f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set " + "from_tf to True to load from this checkpoint." + ) + archive_file = os.path.join( + subfolder, pretrained_model_name_or_path + ".index" + ) + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + # set correct filename + if from_tf: + filename = TF2_WEIGHTS_NAME + elif from_flax: + filename = FLAX_WEIGHTS_NAME + elif use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(PT_WEIGHTS_NAME, variant) + + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file( + pretrained_model_name_or_path, filename, **cached_file_kwargs + ) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant( + SAFE_WEIGHTS_NAME, variant + ): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + elif use_safetensors: + if revision == "main": + ( + resolved_archive_file, + revision, + is_sharded, + ) = auto_conversion( + pretrained_model_name_or_path, **cached_file_kwargs + ) + cached_file_kwargs["revision"] = revision + if resolved_archive_file is None: + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} " + "and thus cannot be loaded with `safetensors`. Please make sure that the model has " + "been saved with `safe_serialization=True` or do not set `use_safetensors=True`." + ) + else: + # This repo has no safetensors file of any kind, we switch to PyTorch. + filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + filename, + **cached_file_kwargs, + ) + if resolved_archive_file is None and filename == _add_variant( + WEIGHTS_NAME, variant + ): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + + if resolved_archive_file is not None: + if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]: + # If the PyTorch file was found, check if there is a safetensors file on the repository + # If there is no safetensors file on the repositories, start an auto conversion + safe_weights_name = ( + SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME + ) + has_file_kwargs = { + "revision": revision, + "token": token, + } + cached_file_kwargs = { + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + **has_file_kwargs, + } + if not has_file( + pretrained_model_name_or_path, + safe_weights_name, + **has_file_kwargs, + ): + Thread( + target=auto_conversion, + args=(pretrained_model_name_or_path,), + kwargs={ + "ignore_errors_during_conversion": True, + **cached_file_kwargs, + }, + name="Thread-autoconversion", + ).start() + else: + # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file. + # We try those to give a helpful error message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + } + if has_file( + pretrained_model_name_or_path, + TF2_WEIGHTS_NAME, + **has_file_kwargs, + ): + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(PT_WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights." + " Use `from_tf=True` to load this model from those weights." + ) + elif has_file( + pretrained_model_name_or_path, + FLAX_WEIGHTS_NAME, + **has_file_kwargs, + ): + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(PT_WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use" + " `from_flax=True` to load this model from those weights." + ) + elif variant is not None and has_file( + pretrained_model_name_or_path, + PT_WEIGHTS_NAME, + **has_file_kwargs, + ): + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(PT_WEIGHTS_NAME, variant)} but there is a file without the variant" + f" {variant}. Use `variant=None` to load this model from those weights." + ) + else: + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(PT_WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or" + f" {FLAX_WEIGHTS_NAME}." + ) + except OSError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception as e: + # For any other exception, we throw a generic error. + raise OSError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(PT_WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." + ) from e + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + else: + logger.info( + f"loading weights file {filename} from cache at {resolved_archive_file}" + ) + else: + resolved_archive_file = None + + # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. + if is_sharded: + # rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=commit_hash, + ) + + if ( + is_safetensors_available() + and isinstance(resolved_archive_file, str) + and resolved_archive_file.endswith(".safetensors") + ): + with safe_open(resolved_archive_file, framework="pt") as f: + metadata = f.metadata() + + if metadata.get("format") == "pt": + pass + elif metadata.get("format") == "tf": + from_tf = True + logger.info( + "A TensorFlow safetensors file is being loaded in a PyTorch model." + ) + elif metadata.get("format") == "flax": + from_flax = True + logger.info("A Flax safetensors file is being loaded in a PyTorch model.") + elif metadata.get("format") == "mlx": + # This is a mlx file, we assume weights are compatible with pt + pass + else: + raise ValueError( + f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}" + ) + + from_pt = not (from_tf | from_flax) diff --git a/outpostkit/repository/_loaders/transformers/constants.py b/outpostkit/repository/_loaders/transformers/constants.py new file mode 100644 index 0000000..4707c39 --- /dev/null +++ b/outpostkit/repository/_loaders/transformers/constants.py @@ -0,0 +1,14 @@ +PT_WEIGHTS_NAME = "pytorch_model.bin" +PT_WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" +TF2_WEIGHTS_NAME = "tf_model.h5" +TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json" +TF_WEIGHTS_NAME = "model.ckpt" +FLAX_WEIGHTS_NAME = "flax_model.msgpack" +FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json" +SAFE_WEIGHTS_NAME = "model.safetensors" +SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" +CONFIG_NAME = "config.json" +FEATURE_EXTRACTOR_NAME = "preprocessor_config.json" +IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME +PROCESSOR_NAME = "processor_config.json" +GENERATION_CONFIG_NAME = "generation_config.json" diff --git a/outpostkit/repository/_loaders/transformers/download.py b/outpostkit/repository/_loaders/transformers/download.py new file mode 100644 index 0000000..72697cd --- /dev/null +++ b/outpostkit/repository/_loaders/transformers/download.py @@ -0,0 +1,83 @@ +import os +from typing import Optional + +from outpostkit._types.repository import REPOSITORY_TYPES +from outpostkit._utils import save_file_at_path_from_response, split_full_name +from outpostkit.client import Client +from outpostkit.repository import RepositoryAtRef + + +def load_local_file_if_present(file_path: str): + if os.path.isfile(file_path): + with open(file_path) as file: + # Perform operations to load the file + # For example, you can read its contents: + file_contents = file.read() + return file_contents + else: + raise FileNotFoundError(f"The file '{file_path}' does not exist.") + + +def is_file_present_locally(file_path: str): + if not os.path.isfile(file_path): + raise FileNotFoundError(f"The file '{file_path}' does not exist.") + + +def download_file_from_repo( + repo_type: REPOSITORY_TYPES, + full_name: str, + file_path: str, + store_dir: str, + client: Optional[Client], + ref: str = "HEAD", +): + try: + (repo_entity, repo_name) = split_full_name(full_name) + except ValueError: + raise FileNotFoundError( + f"Invalid {repo_type} repository fullName or path {full_name}" + ) from None + + if client is None: + client = Client() + repo = RepositoryAtRef( + entity=repo_entity, + name=repo_name, + ref=ref, + repo_type=repo_type, + client=client, + ) + get_file_resp = repo.download_blob(file_path, raw=True) + file_loc = os.path.join(store_dir, file_path) + save_file_at_path_from_response(get_file_resp, file_loc) + return file_loc + + +def get_file( + full_name_or_dir: str, + repo_type: REPOSITORY_TYPES, + file_path: str, + store_dir: str, + ref: str = "HEAD", + token: Optional[str] = None, + client: Optional[Client] = None, + **kwargs, +) -> str: + subfolder = kwargs.pop("subfolder") + if subfolder is not None: + file_path = os.path.join(subfolder, file_path) + if token and not Client: + client = Client(api_token=token) + if os.path.isdir(full_name_or_dir): + file_loc = os.path.join(full_name_or_dir, file_path) + is_file_present_locally(file_loc) + return file_loc + else: + return download_file_from_repo( + repo_type=repo_type, + store_dir=store_dir, + ref=ref, + client=client, + file_path=file_path, + full_name=full_name_or_dir, + ) diff --git a/outpostkit/repository/_loaders/transformers/peft.py b/outpostkit/repository/_loaders/transformers/peft.py new file mode 100644 index 0000000..9c89512 --- /dev/null +++ b/outpostkit/repository/_loaders/transformers/peft.py @@ -0,0 +1,42 @@ +from typing import Optional + +from outpostkit.client import Client +from outpostkit.exceptions import OutpostHTTPException +from outpostkit.logger import init_outpost_logger +from outpostkit.repository._loaders.transformers.download import get_file + +ADAPTER_CONFIG_NAME = "adapter_config.json" +ADAPTER_WEIGHTS_NAME = "adapter_model.bin" +ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors" + +logger = init_outpost_logger(__name__) + + +def find_adapter_config_file( + full_name_or_dir: str, + store_dir: str, + ref: str = "HEAD", + token: Optional[str] = None, + client: Optional[Client] = None, + **kwargs, +) -> Optional[str]: + adapter_cached_filename = None + try: + adapter_cached_filename = get_file( + full_name_or_dir=full_name_or_dir, + file_path=ADAPTER_CONFIG_NAME, + repo_type="model", + store_dir=store_dir, + ref=ref, + token=token, + client=client, + **kwargs, + ) + except FileNotFoundError: + pass + except OutpostHTTPException as e: + if e.code == 404: + logger.warn("Could not find PEFT config file. continuing...") + else: + raise e + return adapter_cached_filename diff --git a/outpostkit/repository/_loaders/transformers/raw.py b/outpostkit/repository/_loaders/transformers/raw.py new file mode 100644 index 0000000..e8a06fd --- /dev/null +++ b/outpostkit/repository/_loaders/transformers/raw.py @@ -0,0 +1,387 @@ +def setup_model_for_transformers( + full_name_or_dir: str, store_dir: str, *model_args, **kwargs +): + if model_kwargs is None: + model_kwargs = {} + # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs, + # this is to keep BC). + use_auth_token = model_kwargs.pop("use_auth_token", None) + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") + token = use_auth_token + + code_revision = kwargs.pop("code_revision", None) + commit_hash = kwargs.pop("_commit_hash", None) + + hub_kwargs = { + "revision": revision, + "token": token, + "trust_remote_code": trust_remote_code, + "_commit_hash": commit_hash, + } + + if task is None and model is None: + raise RuntimeError( + "Impossible to instantiate a pipeline without either a task or a model " + "being specified. " + "Please provide a task class or a model" + ) + + if model is None and tokenizer is not None: + raise RuntimeError( + "Impossible to instantiate a pipeline with tokenizer specified but not the model as the provided tokenizer" + " may not be compatible with the default model. Please provide a PreTrainedModel class or a" + " path/identifier to a pretrained model when providing tokenizer." + ) + if model is None and feature_extractor is not None: + raise RuntimeError( + "Impossible to instantiate a pipeline with feature_extractor specified but not the model as the provided" + " feature_extractor may not be compatible with the default model. Please provide a PreTrainedModel class" + " or a path/identifier to a pretrained model when providing feature_extractor." + ) + if isinstance(model, Path): + model = str(model) + + if commit_hash is None: + pretrained_model_name_or_path = None + if isinstance(config, str): + pretrained_model_name_or_path = config + elif config is None and isinstance(model, str): + pretrained_model_name_or_path = model + + if not isinstance(config, PretrainedConfig) and pretrained_model_name_or_path is not None: + # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible + resolved_config_file = cached_file( + pretrained_model_name_or_path, + CONFIG_NAME, + _raise_exceptions_for_gated_repo=False, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + **hub_kwargs, + ) + hub_kwargs["_commit_hash"] = extract_commit_hash(resolved_config_file, commit_hash) + else: + hub_kwargs["_commit_hash"] = getattr(config, "_commit_hash", None) + + # Config is the primordial information item. + # Instantiate config if needed + if isinstance(config, str): + config = AutoConfig.from_pretrained( + config, _from_pipeline=task, code_revision=code_revision, **hub_kwargs, **model_kwargs + ) + hub_kwargs["_commit_hash"] = config._commit_hash + elif config is None and isinstance(model, str): + # Check for an adapter file in the model path if PEFT is available + if is_peft_available(): + # `find_adapter_config_file` doesn't accept `trust_remote_code` + _hub_kwargs = {k: v for k, v in hub_kwargs.items() if k != "trust_remote_code"} + maybe_adapter_path = find_adapter_config_file( + model, + token=hub_kwargs["token"], + revision=hub_kwargs["revision"], + _commit_hash=hub_kwargs["_commit_hash"], + ) + + if maybe_adapter_path is not None: + with open(maybe_adapter_path, encoding="utf-8") as f: + adapter_config = json.load(f) + model = adapter_config["base_model_name_or_path"] + + config = AutoConfig.from_pretrained( + model, _from_pipeline=task, code_revision=code_revision, **hub_kwargs, **model_kwargs + ) + hub_kwargs["_commit_hash"] = config._commit_hash + + custom_tasks = {} + if config is not None and len(getattr(config, "custom_pipelines", {})) > 0: + custom_tasks = config.custom_pipelines + if task is None and trust_remote_code is not False: + if len(custom_tasks) == 1: + task = list(custom_tasks.keys())[0] + else: + raise RuntimeError( + "We can't infer the task automatically for this model as there are multiple tasks available. Pick " + f"one in {', '.join(custom_tasks.keys())}" + ) + + if task is None and model is not None: + if not isinstance(model, str): + raise RuntimeError( + "Inferring the task automatically requires to check the hub with a model_id defined as a `str`. " + f"{model} is not a valid model_id." + ) + task = get_task(model, token) + + # Retrieve the task + if task in custom_tasks: + normalized_task = task + targeted_task, task_options = clean_custom_task(custom_tasks[task]) + if pipeline_class is None: + if not trust_remote_code: + raise ValueError( + "Loading this pipeline requires you to execute the code in the pipeline file in that" + " repo on your local machine. Make sure you have read the code there to avoid malicious use, then" + " set the option `trust_remote_code=True` to remove this error." + ) + class_ref = targeted_task["impl"] + pipeline_class = get_class_from_dynamic_module( + class_ref, + model, + code_revision=code_revision, + **hub_kwargs, + ) + else: + normalized_task, targeted_task, task_options = check_task(task) + if pipeline_class is None: + pipeline_class = targeted_task["impl"] + + # Use default model/config/tokenizer for the task if no model is provided + if model is None: + # At that point framework might still be undetermined + model, default_revision = get_default_model_and_revision(targeted_task, framework, task_options) + revision = revision if revision is not None else default_revision + logger.warning( + f"No model was supplied, defaulted to {model} and revision" + f" {revision} ({HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n" + "Using a pipeline without specifying a model name and revision in production is not recommended." + ) + if config is None and isinstance(model, str): + config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs) + hub_kwargs["_commit_hash"] = config._commit_hash + + if device_map is not None: + if "device_map" in model_kwargs: + raise ValueError( + 'You cannot use both `pipeline(... device_map=..., model_kwargs={"device_map":...})` as those' + " arguments might conflict, use only one.)" + ) + if device is not None: + logger.warning( + "Both `device` and `device_map` are specified. `device` will override `device_map`. You" + " will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`." + ) + model_kwargs["device_map"] = device_map + if torch_dtype is not None: + if "torch_dtype" in model_kwargs: + raise ValueError( + 'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those' + " arguments might conflict, use only one.)" + ) + if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype): + torch_dtype = getattr(torch, torch_dtype) + model_kwargs["torch_dtype"] = torch_dtype + + model_name = model if isinstance(model, str) else None + + # Load the correct model if possible + # Infer the framework from the model if not already defined + if isinstance(model, str) or framework is None: + model_classes = {"tf": targeted_task["tf"], "pt": targeted_task["pt"]} + framework, model = infer_framework_load_model( + model, + model_classes=model_classes, + config=config, + framework=framework, + task=task, + **hub_kwargs, + **model_kwargs, + ) + + model_config = model.config + hub_kwargs["_commit_hash"] = model.config._commit_hash + load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None + load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None + load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None + + # If `model` (instance of `PretrainedModel` instead of `str`) is passed (and/or same for config), while + # `image_processor` or `feature_extractor` is `None`, the loading will fail. This happens particularly for some + # vision tasks when calling `pipeline()` with `model` and only one of the `image_processor` and `feature_extractor`. + # TODO: we need to make `NO_IMAGE_PROCESSOR_TASKS` and `NO_FEATURE_EXTRACTOR_TASKS` more robust to avoid such issue. + # This block is only temporarily to make CI green. + if load_image_processor and load_feature_extractor: + load_feature_extractor = False + + if ( + tokenizer is None + and not load_tokenizer + and normalized_task not in NO_TOKENIZER_TASKS + # Using class name to avoid importing the real class. + and ( + model_config.__class__.__name__ in MULTI_MODEL_AUDIO_CONFIGS + or model_config.__class__.__name__ in MULTI_MODEL_VISION_CONFIGS + ) + ): + # This is a special category of models, that are fusions of multiple models + # so the model_config might not define a tokenizer, but it seems to be + # necessary for the task, so we're force-trying to load it. + load_tokenizer = True + if ( + image_processor is None + and not load_image_processor + and normalized_task not in NO_IMAGE_PROCESSOR_TASKS + # Using class name to avoid importing the real class. + and model_config.__class__.__name__ in MULTI_MODEL_VISION_CONFIGS + ): + # This is a special category of models, that are fusions of multiple models + # so the model_config might not define a tokenizer, but it seems to be + # necessary for the task, so we're force-trying to load it. + load_image_processor = True + if ( + feature_extractor is None + and not load_feature_extractor + and normalized_task not in NO_FEATURE_EXTRACTOR_TASKS + # Using class name to avoid importing the real class. + and model_config.__class__.__name__ in MULTI_MODEL_AUDIO_CONFIGS + ): + # This is a special category of models, that are fusions of multiple models + # so the model_config might not define a tokenizer, but it seems to be + # necessary for the task, so we're force-trying to load it. + load_feature_extractor = True + + if task in NO_TOKENIZER_TASKS: + # These will never require a tokenizer. + # the model on the other hand might have a tokenizer, but + # the files could be missing from the hub, instead of failing + # on such repos, we just force to not load it. + load_tokenizer = False + + if task in NO_FEATURE_EXTRACTOR_TASKS: + load_feature_extractor = False + if task in NO_IMAGE_PROCESSOR_TASKS: + load_image_processor = False + + if load_tokenizer: + # Try to infer tokenizer from model or config name (if provided as str) + if tokenizer is None: + if isinstance(model_name, str): + tokenizer = model_name + elif isinstance(config, str): + tokenizer = config + else: + # Impossible to guess what is the right tokenizer here + raise Exception( + "Impossible to guess which tokenizer to use. " + "Please provide a PreTrainedTokenizer class or a path/identifier to a pretrained tokenizer." + ) + + # Instantiate tokenizer if needed + if isinstance(tokenizer, (str, tuple)): + if isinstance(tokenizer, tuple): + # For tuple we have (tokenizer name, {kwargs}) + use_fast = tokenizer[1].pop("use_fast", use_fast) + tokenizer_identifier = tokenizer[0] + tokenizer_kwargs = tokenizer[1] + else: + tokenizer_identifier = tokenizer + tokenizer_kwargs = model_kwargs.copy() + tokenizer_kwargs.pop("torch_dtype", None) + + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_identifier, use_fast=use_fast, _from_pipeline=task, **hub_kwargs, **tokenizer_kwargs + ) + + if load_image_processor: + # Try to infer image processor from model or config name (if provided as str) + if image_processor is None: + if isinstance(model_name, str): + image_processor = model_name + elif isinstance(config, str): + image_processor = config + # Backward compatibility, as `feature_extractor` used to be the name + # for `ImageProcessor`. + elif feature_extractor is not None and isinstance(feature_extractor, BaseImageProcessor): + image_processor = feature_extractor + else: + # Impossible to guess what is the right image_processor here + raise Exception( + "Impossible to guess which image processor to use. " + "Please provide a PreTrainedImageProcessor class or a path/identifier " + "to a pretrained image processor." + ) + + # Instantiate image_processor if needed + if isinstance(image_processor, (str, tuple)): + image_processor = AutoImageProcessor.from_pretrained( + image_processor, _from_pipeline=task, **hub_kwargs, **model_kwargs + ) + + if load_feature_extractor: + # Try to infer feature extractor from model or config name (if provided as str) + if feature_extractor is None: + if isinstance(model_name, str): + feature_extractor = model_name + elif isinstance(config, str): + feature_extractor = config + else: + # Impossible to guess what is the right feature_extractor here + raise Exception( + "Impossible to guess which feature extractor to use. " + "Please provide a PreTrainedFeatureExtractor class or a path/identifier " + "to a pretrained feature extractor." + ) + + # Instantiate feature_extractor if needed + if isinstance(feature_extractor, (str, tuple)): + feature_extractor = AutoFeatureExtractor.from_pretrained( + feature_extractor, _from_pipeline=task, **hub_kwargs, **model_kwargs + ) + + if ( + feature_extractor._processor_class + and feature_extractor._processor_class.endswith("WithLM") + and isinstance(model_name, str) + ): + try: + import kenlm # to trigger `ImportError` if not installed + from pyctcdecode import BeamSearchDecoderCTC + + if os.path.isdir(model_name) or os.path.isfile(model_name): + decoder = BeamSearchDecoderCTC.load_from_dir(model_name) + else: + language_model_glob = os.path.join( + BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*" + ) + alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME + allow_patterns = [language_model_glob, alphabet_filename] + decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_patterns=allow_patterns) + + kwargs["decoder"] = decoder + except ImportError as e: + logger.warning(f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Error: {e}") + if not is_kenlm_available(): + logger.warning("Try to install `kenlm`: `pip install kenlm") + + if not is_pyctcdecode_available(): + logger.warning("Try to install `pyctcdecode`: `pip install pyctcdecode") + + if task == "translation" and model.config.task_specific_params: + for key in model.config.task_specific_params: + if key.startswith("translation"): + task = key + warnings.warn( + f'"translation" task was used, instead of "translation_XX_to_YY", defaulting to "{task}"', + UserWarning, + ) + break + + if tokenizer is not None: + kwargs["tokenizer"] = tokenizer + + if feature_extractor is not None: + kwargs["feature_extractor"] = feature_extractor + + if torch_dtype is not None: + kwargs["torch_dtype"] = torch_dtype + + if image_processor is not None: + kwargs["image_processor"] = image_processor + + if device is not None: + kwargs["device"] = device + + return pipeline_class(model=model, framework=framework, task=task, **kwargs) diff --git a/outpostkit/repository/_loaders/transformers/utils.py b/outpostkit/repository/_loaders/transformers/utils.py new file mode 100644 index 0000000..e69de29 diff --git a/outpostkit/repository/download.py b/outpostkit/repository/download.py new file mode 100644 index 0000000..5c29720 --- /dev/null +++ b/outpostkit/repository/download.py @@ -0,0 +1,5 @@ +from outpostkit.repository import Repository + + +def download_file_from_repo(full_name:str, filepath:str): + repo = Repository diff --git a/pyproject.toml b/pyproject.toml index ebdd0d4..674d650 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "packaging", "pydantic>1", "typing_extensions>=4.5.0", + "dataclasses_json", ] optional-dependencies = { dev = [ "pylint",