[deps] upgrade transformers (#8159)
This commit is contained in:
@@ -95,7 +95,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
def check_dependencies() -> None:
|
||||
r"""Check the version of the required packages."""
|
||||
check_version(
|
||||
"transformers>=4.45.0,<=4.52.1,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
|
||||
"transformers>=4.45.0,<=4.52.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
|
||||
)
|
||||
check_version("datasets>=2.16.0,<=3.6.0")
|
||||
check_version("accelerate>=0.34.0,<=1.7.0")
|
||||
|
||||
@@ -163,7 +163,7 @@ def save_args(config_path: str, config_dict: dict[str, Any]) -> None:
|
||||
|
||||
def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
|
||||
r"""Remove args with NoneType or False or empty string value."""
|
||||
no_skip_keys = ["packing"]
|
||||
no_skip_keys = ["packing", "freeze_vision_tower", "freeze_multi_modal_projector", "freeze_language_model"]
|
||||
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
|
||||
|
||||
|
||||
|
||||
@@ -22,14 +22,13 @@ from typing import TYPE_CHECKING, Any, Optional
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.utils import is_torch_npu_available
|
||||
|
||||
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
|
||||
from ..extras.constants import LLAMABOARD_CONFIG, MULTIMODAL_SUPPORTED_MODELS, PEFT_METHODS, TRAINING_STAGES
|
||||
from ..extras.misc import is_accelerator_available, torch_gc, use_ray
|
||||
from ..extras.packages import is_gradio_available
|
||||
from .common import (
|
||||
DEFAULT_CACHE_DIR,
|
||||
DEFAULT_CONFIG_DIR,
|
||||
abort_process,
|
||||
calculate_pixels,
|
||||
gen_cmd,
|
||||
get_save_dir,
|
||||
load_args,
|
||||
@@ -165,13 +164,6 @@ class Runner:
|
||||
use_llama_pro=get("train.use_llama_pro"),
|
||||
enable_thinking=get("train.enable_thinking"),
|
||||
report_to=get("train.report_to"),
|
||||
freeze_vision_tower=get("train.freeze_vision_tower"),
|
||||
freeze_multi_modal_projector=get("train.freeze_multi_modal_projector"),
|
||||
freeze_language_model=get("train.freeze_language_model"),
|
||||
image_max_pixels=calculate_pixels(get("train.image_max_pixels")),
|
||||
image_min_pixels=calculate_pixels(get("train.image_min_pixels")),
|
||||
video_max_pixels=calculate_pixels(get("train.video_max_pixels")),
|
||||
video_min_pixels=calculate_pixels(get("train.video_min_pixels")),
|
||||
use_galore=get("train.use_galore"),
|
||||
use_apollo=get("train.use_apollo"),
|
||||
use_badam=get("train.use_badam"),
|
||||
@@ -244,6 +236,16 @@ class Runner:
|
||||
args["pref_ftx"] = get("train.pref_ftx")
|
||||
args["pref_loss"] = get("train.pref_loss")
|
||||
|
||||
# multimodal config
|
||||
if model_name in MULTIMODAL_SUPPORTED_MODELS:
|
||||
args["freeze_vision_tower"] = get("train.freeze_vision_tower")
|
||||
args["freeze_multi_modal_projector"] = get("train.freeze_multi_modal_projector")
|
||||
args["freeze_language_model"] = get("train.freeze_language_model")
|
||||
args["image_max_pixels"] = get("train.image_max_pixels")
|
||||
args["image_min_pixels"] = get("train.image_min_pixels")
|
||||
args["video_max_pixels"] = get("train.video_max_pixels")
|
||||
args["video_min_pixels"] = get("train.video_min_pixels")
|
||||
|
||||
# galore config
|
||||
if args["use_galore"]:
|
||||
args["galore_rank"] = get("train.galore_rank")
|
||||
|
||||
Reference in New Issue
Block a user