refactor model_dtype, fix PPO trainer
Former-commit-id: 3e17ee5afbcb823a7c9a2f91864b3750cd79edb4
@@ -8,16 +8,6 @@ from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint
 
-try:
-    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available, is_torch_cuda_available
-    is_fp16_available = is_torch_cuda_available()
-    is_bf16_available = is_torch_bf16_gpu_available()
-    is_npu_available = is_torch_npu_available()
-except ImportError:
-    is_fp16_available = torch.cuda.is_available()
-    is_bf16_available = torch.cuda.is_bf16_supported()
-    is_npu_available = False
-
 from llmtuner.extras.logging import get_logger
 from llmtuner.hparams import (
     ModelArguments,
@@ -31,17 +21,6 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)
 
 
-def _infer_dtype() -> torch.dtype:
-    if is_npu_available:
-        return torch.float16
-    elif is_bf16_available:
-        return torch.bfloat16
-    elif is_fp16_available:
-        return torch.float16
-    else:
-        return torch.float32
-
-
 def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
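Note: the two hunks above delete the module-level device probing and the `_infer_dtype` fallback chain. For reference, the deleted logic condensed into one self-contained snippet (content taken verbatim from the removed lines; only the grouping is new):

# Condensed from the deleted lines above: probe capabilities once, then
# pick a dtype with the priority NPU-fp16 > bf16 > fp16 > fp32.
import torch

try:
    from transformers.utils import (
        is_torch_bf16_gpu_available,
        is_torch_cuda_available,
        is_torch_npu_available,
    )
    is_fp16_available = is_torch_cuda_available()
    is_bf16_available = is_torch_bf16_gpu_available()
    is_npu_available = is_torch_npu_available()
except ImportError:  # older transformers without these helpers
    is_fp16_available = torch.cuda.is_available()
    is_bf16_available = torch.cuda.is_bf16_supported()
    is_npu_available = False

def _infer_dtype() -> torch.dtype:
    # NPU devices train in fp16; otherwise prefer bf16, then fp16, then fp32.
    if is_npu_available:
        return torch.float16
    elif is_bf16_available:
        return torch.bfloat16
    elif is_fp16_available:
        return torch.float16
    return torch.float32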
@@ -178,12 +157,15 @@ def get_train_args(
         if not finetuning_args.resume_lora_training:
             raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")
 
-    if model_args.quantization_bit is not None and (not training_args.do_train):
-        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
+    if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
+        logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
+
+    if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
+        logger.warning("We recommend enable mixed precision training.")
 
     if (not training_args.do_train) and model_args.quantization_bit is not None:
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
 
     # postprocess data_args
     if data_args.max_samples is not None and data_args.streaming:
         logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
@@ -206,10 +188,9 @@ def get_train_args(
         and os.path.isdir(training_args.output_dir)
         and not training_args.overwrite_output_dir
     ):
-        require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
-            raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
+            raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
 
         if last_checkpoint is not None:
             training_args_dict = training_args.to_dict()
@@ -220,26 +201,7 @@ def get_train_args(
         )
 
     # postprocess model_args
-    if training_args.bf16:
-        if not is_bf16_available:
-            raise ValueError("Current device does not support bf16 training.")
-        model_args.compute_dtype = torch.bfloat16
-    elif training_args.fp16:
-        model_args.compute_dtype = torch.float16
-    else:
-        model_args.compute_dtype = _infer_dtype()
-
-    if model_args.layernorm_dtype == "bf16":
-        if not is_bf16_available:
-            raise ValueError("Current device does not support bf16 type.")
-        model_args.layernorm_dtype = torch.bfloat16
-    elif model_args.layernorm_dtype == "fp16":
-        model_args.layernorm_dtype = torch.float16
-    elif model_args.layernorm_dtype == "fp32":
-        model_args.layernorm_dtype = torch.float32
-    else:
-        model_args.layernorm_dtype = model_args.compute_dtype
-
+    model_args.compute_dtype = torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
     model_args.model_max_length = data_args.cutoff_len
 
     # Log on each process the small summary:
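Note: the replacement line above folds the whole bf16/fp16 branch into a single ternary. A minimal sketch of the resulting rule, assuming `None` means the dtype decision is deferred to model loading rather than forced to fp32:

import torch
from typing import Optional

def select_compute_dtype(bf16: bool, fp16: bool) -> Optional[torch.dtype]:
    # Mirrors the added one-liner: explicit flags win, None defers the choice.
    return torch.bfloat16 if bf16 else (torch.float16 if fp16 else None)

assert select_compute_dtype(True, False) is torch.bfloat16
assert select_compute_dtype(False, True) is torch.float16
assert select_compute_dtype(False, False) is None  # decided later, downstream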
@@ -278,7 +240,4 @@ def get_infer_args(
     if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
         raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
 
-    # auto-detect cuda capability
-    model_args.compute_dtype = _infer_dtype()
-
     return model_args, data_args, finetuning_args, generating_args
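Note: with both `_infer_dtype` call sites removed, the auto-detection presumably moves next to model loading, where the checkpoint dtype is known. A hypothetical sketch of such a helper (the name `infer_optim_dtype` and its placement are assumptions, not part of this diff):

import torch

def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
    # Hypothetical: upgrade to bf16 on GPUs that support it, otherwise
    # keep the dtype the weights were saved in.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return model_dtype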