Merge branch 'main' into cpei/refactor

Former-commit-id: c2951f17f726470bcd5dff6bf7028ec90212442e
This commit is contained in:
hoshi-hiyouga
2024-10-08 17:31:17 +08:00
committed by GitHub
17 changed files with 775 additions and 240 deletions

View File

@@ -21,12 +21,12 @@ from trl import AutoModelForCausalLMWithValueHead
from ..extras.logging import get_logger
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .model_utils.visual import get_image_seqlen
from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model
if TYPE_CHECKING:
@@ -61,7 +61,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
r"""
Loads pretrained tokenizer.
Loads pretrained tokenizer and optionally loads processor.
Note: including inplace operation of model_args.
"""
@@ -96,15 +96,9 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
logger.warning("New tokens have been added, changed `resize_vocab` to True.")
patch_tokenizer(tokenizer)
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
setattr(processor, "tokenizer", tokenizer)
setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "image_resolution", model_args.image_resolution)
setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
patch_processor(processor, config, tokenizer, model_args)
except Exception as e:
logger.warning("Failed to load processor. Error: {}".format(e))
processor = None
@@ -138,6 +132,7 @@ def load_model(
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
apply_liger_kernel(config, model_args, is_trainable, require_logits=(finetuning_args.stage not in ["pt", "sft"]))
model = None
lazy_load = False
@@ -158,7 +153,6 @@ def load_model(
load_class = AutoModelForVision2Seq
else:
load_class = AutoModelForCausalLM
if model_args.train_from_scratch:
model = load_class.from_config(config)
else: