support rope scaling, fix #475 #476 #478

Former-commit-id: 337d5f68b72230e545e7a94ca789187c7a2b7187
Author: hiyouga
Date: 2023-08-12 20:46:27 +08:00
Parent: cde9f3db57
Commit: fdfb644f0a

12 changed files with 267 additions and 277 deletions
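The rope-scaling support named in the commit title lives in files not included in this excerpt. For orientation only, here is a minimal, hypothetical sketch of how RoPE scaling is typically enabled through a transformers config; the model id, the sequence lengths, and the ceil-based scaling factor are illustrative assumptions, not necessarily this commit's exact logic.

    import math

    from transformers import AutoConfig, AutoModelForCausalLM

    # Placeholder model id and lengths; adjust to your setup.
    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
    model_max_length = 8192  # e.g. max_source_length + max_target_length
    orig_max_length = getattr(config, "max_position_embeddings", 2048)

    # Stretch RoPE only when the requested length exceeds the pretraining context.
    if model_max_length > orig_max_length:
        scaling_factor = float(math.ceil(model_max_length / orig_max_length))
        # transformers >= 4.31 accepts this dict on LLaMA-style configs.
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", config=config)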


@@ -151,13 +151,16 @@ def get_train_args(
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
 
     if model_args.quantization_bit is not None:
-        if training_args.fp16:
-            model_args.compute_dtype = torch.float16
-        elif training_args.bf16:
-            model_args.compute_dtype = torch.bfloat16
-        else:
-            model_args.compute_dtype = torch.float32
+        if training_args.fp16:
+            model_args.compute_dtype = torch.float16
+        elif training_args.bf16:
+            if not torch.cuda.is_bf16_supported():
+                raise ValueError("Current device does not support bf16 training.")
+            model_args.compute_dtype = torch.bfloat16
+        else:
+            model_args.compute_dtype = torch.float32
+
+    model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
 
     # Log on each process the small summary:
     logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format(