improve aqlm optim

Former-commit-id: 81be999b407e988c2f42764d827ac859d079ed3e
This commit is contained in:
hiyouga
2024-03-05 20:49:50 +08:00
parent a10bead9b5
commit 46ee267cfc
4 changed files with 7 additions and 3 deletions

View File

@@ -226,6 +226,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
)
model_args.model_max_length = data_args.cutoff_len
model_args.aqlm_optimization = not training_args.predict_with_generate
# Log on each process the small summary:
logger.info(
@@ -262,6 +263,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
_set_transformers_logging()
_verify_model_args(model_args, finetuning_args)
_check_dependencies(disabled=finetuning_args.disable_version_checking)
model_args.aqlm_optimization = True
if data_args.template is None:
raise ValueError("Please specify which `template` to use.")