add cal_lr.py

Former-commit-id: cea2ba17efc47917e63437a376f220864f7f90dd
This commit is contained in:
hiyouga
2023-11-14 20:58:37 +08:00
parent c9a4551012
commit 75dd1f0f7e
4 changed files with 67 additions and 6 deletions

View File

@@ -84,10 +84,9 @@ def load_model_and_tokenizer(
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
# Set model dtype
if model_args.compute_dtype is not None: # for training
setattr(config, "torch_dtype", model_args.compute_dtype)
else: # for evaluation, priority: bf16 > fp16 > fp32
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
setattr(config, "torch_dtype", model_args.compute_dtype)
# Fix config (for Qwen)
if getattr(config, "model_type", None) == "qwen":