update training resuming
Former-commit-id: 2ec75c31f609e65116ac3b621eeb7d8ccbf69135
This commit is contained in:
@@ -5,6 +5,7 @@ import datasets
|
||||
import transformers
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.hparams import (
|
||||
@@ -97,30 +98,33 @@ def get_train_args(
|
||||
if general_args.stage != "sft" and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
||||
|
||||
if training_args.do_train and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||
|
||||
if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
||||
|
||||
if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("RM and PPO training can only be performed with the LoRA method.")
|
||||
raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
|
||||
|
||||
if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
|
||||
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
|
||||
|
||||
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
|
||||
raise ValueError("PPO and DPO stage can only be performed at training.")
|
||||
raise ValueError("PPO and DPO stages can only be performed at training.")
|
||||
|
||||
if general_args.stage == "ppo" and model_args.reward_model is None:
|
||||
raise ValueError("Reward model is necessary for PPO training.")
|
||||
|
||||
if training_args.max_steps == -1 and data_args.streaming:
|
||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||
|
||||
if general_args.stage == "ppo" and data_args.streaming:
|
||||
raise ValueError("Streaming mode does not suppport PPO training currently.")
|
||||
|
||||
if training_args.max_steps == -1 and data_args.streaming:
|
||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||
|
||||
if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
|
||||
raise ValueError("Streaming mode should have an integer val size.")
|
||||
|
||||
if training_args.do_train and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
@@ -134,9 +138,15 @@ def get_train_args(
|
||||
if model_args.quantization_bit is not None and (not training_args.do_train):
|
||||
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
|
||||
if training_args.do_train and (not training_args.fp16):
|
||||
logger.warning("We recommend enable fp16 mixed precision training.")
|
||||
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
|
||||
logger.warning("We recommend enable mixed precision training.")
|
||||
|
||||
# postprocess data_args
|
||||
if data_args.max_samples is not None and data_args.streaming:
|
||||
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
|
||||
data_args.max_samples = None
|
||||
|
||||
# postprocess training_args
|
||||
if (
|
||||
training_args.local_rank != -1
|
||||
and training_args.ddp_find_unused_parameters is None
|
||||
@@ -145,12 +155,26 @@ def get_train_args(
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
|
||||
training_args.ddp_find_unused_parameters = False
|
||||
|
||||
if data_args.max_samples is not None and data_args.streaming:
|
||||
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
|
||||
data_args.max_samples = None
|
||||
if training_args.optim == "adamw_hf":
|
||||
training_args.optim = "adamw_torch" # suppress warning
|
||||
|
||||
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
|
||||
if (
|
||||
training_args.resume_from_checkpoint is None
|
||||
and training_args.do_train
|
||||
and os.path.isdir(training_args.output_dir)
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
|
||||
|
||||
if last_checkpoint is not None:
|
||||
training_args.resume_from_checkpoint = last_checkpoint
|
||||
logger.info(
|
||||
"Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
|
||||
)
|
||||
|
||||
# postprocess model_args
|
||||
if training_args.bf16:
|
||||
if not torch.cuda.is_bf16_supported():
|
||||
raise ValueError("Current device does not support bf16 training.")
|
||||
|
||||
Reference in New Issue
Block a user