@@ -10,6 +10,7 @@ from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.packages import is_unsloth_available
|
||||
from ..extras.misc import check_dependencies
|
||||
from .data_args import DataArguments
|
||||
from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
@@ -20,6 +21,9 @@ from .model_args import ModelArguments
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
check_dependencies()
|
||||
|
||||
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
@@ -221,7 +225,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
training_args.local_rank,
|
||||
training_args.device,
|
||||
training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1),
|
||||
training_args.parallel_mode.value == "distributed",
|
||||
str(model_args.compute_dtype),
|
||||
)
|
||||
)
|
||||
@@ -236,6 +240,8 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
|
||||
_set_transformers_logging()
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
model_args.aqlm_optimization = False
|
||||
model_args.device_map = "auto"
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
@@ -249,6 +255,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
_set_transformers_logging()
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
model_args.aqlm_optimization = True
|
||||
model_args.device_map = "auto"
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
Reference in New Issue
Block a user