Commit

feat(nn4k): add huggingface decode only model local sft feature (#1) (#109)

Co-authored-by: xionghuaidong <[email protected]>
chenbin11200 and xionghuaidong authored Feb 22, 2024
1 parent e95725d commit eb2590a
Showing 29 changed files with 1,252 additions and 206 deletions.
1 change: 1 addition & 0 deletions .licenserc.yaml
@@ -32,6 +32,7 @@ header:
- '**/*.schema'
- '**/*.rule'
- '**/*.json'
- '**/*.json5'
- '**/*.in'
- '**/META-INF/services/*'
- '**/*.conf'
3 changes: 1 addition & 2 deletions python/nn4k/nn4k/consts/__init__.py
@@ -22,8 +22,7 @@
NN_EXECUTOR_KEY = "nn_executor"
NN_EXECUTOR_TEXT = "NN executor"

NN_DEVICE_KEY = "device"
NN_TRUST_REMOTE_CODE_KEY = "trust_remote_code"
NN_DEVICE_KEY = "nn_device"

NN_OPENAI_MODEL_NAME_KEY = NN_NAME_KEY
NN_OPENAI_MODEL_NAME_TEXT = "openai model name"
2 changes: 1 addition & 1 deletion python/nn4k/nn4k/executor/__init__.py
@@ -9,4 +9,4 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from nn4k.executor.base import NNExecutor, LLMExecutor
from nn4k.executor.base import NNExecutor, LLMExecutor, NNModelArgs, NNAdapterModelArgs
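With this re-export (a sketch of the assumed public surface after this change), the executors and the new args dataclasses can be imported from nn4k.executor directly:

# Assumed import surface after this change; these names are re-exported from nn4k.executor.base.
from nn4k.executor import NNExecutor, LLMExecutor, NNModelArgs, NNAdapterModelArgs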
93 changes: 91 additions & 2 deletions python/nn4k/nn4k/executor/base.py
@@ -10,7 +10,8 @@
# or implied.

from abc import ABC, abstractmethod
from typing import Union
from dataclasses import dataclass, field
from typing import Optional, Union


class NNExecutor(ABC):
@@ -145,7 +146,23 @@ def from_config(cls, nn_config: Union[str, dict]) -> "NNExecutor":
raise RuntimeError(message)


class LLMExecutor(NNExecutor):
class LLMExecutor(NNExecutor, ABC):
"""
Base Executor for LLM.
"""

@classmethod
def from_config(cls, nn_config: Union[str, dict]) -> "LLMExecutor":
"""
Implement dispatch logic for LLM executors. Since only Hugging Face decode-only models are supported for now,
this points directly to HFDecodeOnlyExecutor; hub management functions will be used later on.
"""
from nn4k.executor.huggingface.hf_decode_only_executor import (
HFDecodeOnlyExecutor,
)

return HFDecodeOnlyExecutor.from_config(nn_config)

def execute_sft(self, args=None, callbacks=None, **kwargs):
"""
The entry point of SFT execution in a certain pod.
@@ -159,3 +176,75 @@ def execute_rl_tuning(self, args=None, callbacks=None, **kwargs):
raise NotImplementedError(
f"{self.__class__.__name__} does not support RL-Tuning."
)


@dataclass
class NNModelArgs:
"""
Base args for NN4K-supported model definition and loading.
"""

nn_name: Optional[str] = field(
default=None,
metadata={"help": ("NN4K model name")},
)
nn_version: Optional[str] = field(
default="default",
metadata={"help": ("NN4K model version, by default is 'default'")},
)
nn_model_path: Optional[str] = field(
default=None,
metadata={
"help": (
"model path dir, could be delivered by user or get managed in Hub."
)
},
)
nn_device: Optional[str] = field(
default="auto", metadata={"help": ("device to use to load model")}
)

def __post_init__(self):
assert (
self.nn_name is not None or self.nn_model_path is not None
), "either nn_name or nn_model_path has to be provided"


@dataclass
class NNAdapterModelArgs(NNModelArgs):
"""
Use this args dataclass to enable adapter models.
"""

adapter_name: str = field(
default=None,
metadata={
"help": "adapter name. Should be provided if you want to sft or load a adapter model."
},
)
adapter_version: str = field(
default="auto",
metadata={
"help": "adapter is designed to get managed by versions, by default is 'latest'"
},
)
adapter_type: str = field(
default="lora", metadata={"help": "adapter type, lora by default."}
)
adapter_path: str = field(
default=None,
metadata={
"help": "adapter weight and config path, could be delivered by user or get managed in Hub."
},
)
adapter_config: Optional[dict] = field(
default=None,
metadata={
"help": "Only necessary if you want to init a new adapter model and train from scratch or resume"
"from a checkpoint (in this case, should be the same as the previous adapter_config)."
"Values are the same as peft config init args."
},
)

def __post_init__(self):
super().__post_init__()
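Taken together, the new from_config dispatch and the args dataclasses suggest a usage pattern along the following lines. This is a minimal sketch under assumptions not confirmed by this diff alone: the config keys are assumed to mirror the dataclass field names above, the model name and paths are hypothetical, and HFDecodeOnlyExecutor is assumed to accept a config dict of this shape.

# Minimal usage sketch; model names/paths are hypothetical and the accepted
# config keys are assumed to mirror the dataclass fields defined above.
from nn4k.executor import LLMExecutor, NNAdapterModelArgs

# Build adapter-model args directly; __post_init__ requires nn_name or nn_model_path.
args = NNAdapterModelArgs(
    nn_name="my_base_model",          # hypothetical model name
    nn_model_path="/models/my_base",  # hypothetical local checkpoint dir
    nn_device="auto",
    adapter_name="my_lora_adapter",   # hypothetical adapter name
    adapter_type="lora",
)

# LLMExecutor.from_config currently dispatches to HFDecodeOnlyExecutor.
executor = LLMExecutor.from_config(
    {
        "nn_name": "my_base_model",
        "nn_model_path": "/models/my_base",
        "nn_device": "auto",
    }
)
# executor.execute_sft(...)  # SFT entry point; exact argument handling is defined by the subclass.

The dispatch in LLMExecutor.from_config keeps the call site stable while the concrete executor can later be resolved through the hub, as the docstring notes.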
152 changes: 0 additions & 152 deletions python/nn4k/nn4k/executor/hugging_face.py

This file was deleted.

13 changes: 13 additions & 0 deletions python/nn4k/nn4k/executor/huggingface/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from nn4k.executor.huggingface.base.hf_llm_executor import HFLLMExecutor
from nn4k.executor.huggingface.base.hf_args import HFModelArgs, HFSftArgs
10 changes: 10 additions & 0 deletions python/nn4k/nn4k/executor/huggingface/base/__init__.py
@@ -0,0 +1,10 @@
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
