Former-commit-id: c253c18185a29b59190f3e0ed236c2bb4c788085
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from .loader import load_model, load_tokenizer
|
||||
from .loader import load_config, load_model, load_tokenizer
|
||||
from .utils import find_all_linear_modules, load_valuehead_params
|
||||
|
||||
|
||||
__all__ = [
|
||||
"load_config",
|
||||
"load_model",
|
||||
"load_tokenizer",
|
||||
"load_valuehead_params",
|
||||
|
||||
@@ -12,7 +12,7 @@ from .utils import load_valuehead_params, register_autoclass
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
@@ -21,6 +21,11 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
||||
r"""
|
||||
Gets arguments to load config/tokenizer/model.
|
||||
|
||||
Note: including inplace operation of model_args.
|
||||
"""
|
||||
model_args.model_name_or_path = try_download_model_from_ms(model_args)
|
||||
return {
|
||||
"trust_remote_code": True,
|
||||
@@ -32,9 +37,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
||||
|
||||
def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
|
||||
r"""
|
||||
Loads pretrained tokenizer. Must before load_model.
|
||||
|
||||
Note: including inplace operation of model_args.
|
||||
Loads pretrained tokenizer.
|
||||
"""
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
try:
|
||||
@@ -57,6 +60,14 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
|
||||
return tokenizer
|
||||
|
||||
|
||||
def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
|
||||
r"""
|
||||
Loads model config.
|
||||
"""
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||
|
||||
|
||||
def load_model(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
model_args: "ModelArguments",
|
||||
@@ -65,10 +76,10 @@ def load_model(
|
||||
add_valuehead: bool = False,
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Loads pretrained model. Must after load_tokenizer.
|
||||
Loads pretrained model.
|
||||
"""
|
||||
init_kwargs = _get_init_kwargs(model_args)
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||
config = load_config(model_args)
|
||||
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
|
||||
|
||||
model = None
|
||||
|
||||
Reference in New Issue
Block a user