Fix saving of custom code files when registering auto classes

Former-commit-id: 3f8f40bffd4f61fcc045f5f8a07420f3b46d0f7a
This commit is contained in:
hiyouga
2023-07-16 18:04:41 +08:00
parent c61de6f669
commit e9736b2ba0
2 changed files with 89 additions and 24 deletions

View File

@@ -11,7 +11,7 @@ from transformers import (
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils import PreTrainedTokenizerBase
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.logging import get_logger
@@ -36,7 +36,7 @@ def load_model_and_tokenizer(
finetuning_args: FinetuningArguments,
is_trainable: Optional[bool] = False,
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
r"""
Loads pretrained model and tokenizer.
@@ -113,12 +113,12 @@ def load_model_and_tokenizer(
)
# Register auto class to save the custom code files.
if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map and isinstance(config, PretrainedConfig):
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map and isinstance(tokenizer, PreTrainedTokenizer):
tokenizer.__class__.register_for_auto_class()
if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map and isinstance(model, PreTrainedModel):
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()
# Initialize adapters
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model