improve lora+ impl.

Former-commit-id: 332bad25455a70ad9204e7dd384bb086d789aa39
This commit is contained in:
hiyouga
2024-03-13 23:32:51 +08:00
parent 73f4513c84
commit 46f99ff277
12 changed files with 165 additions and 169 deletions

View File

@@ -43,8 +43,10 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
def export_model(args: Optional[Dict[str, Any]] = None):
model_args, data_args, finetuning_args, _ = get_infer_args(args)
model_args.device_map = {"": "cpu"}
if model_args.export_dir is None:
raise ValueError("Please specify `export_dir`.")
raise ValueError("Please specify `export_dir` to save model.")
if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
raise ValueError("Please merge adapters before quantizing the model.")
@@ -58,13 +60,10 @@ def export_model(args: Optional[Dict[str, Any]] = None):
if not isinstance(model, PreTrainedModel):
raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
if getattr(model, "quantization_method", None):
model = model.to("cpu")
elif hasattr(model.config, "torch_dtype"):
model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
else:
model = model.to(torch.float16).to("cpu")
setattr(model.config, "torch_dtype", torch.float16)
if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model
output_dtype = getattr(model.config, "torch_dtype", torch.float16)
model = model.to(output_dtype)
setattr(model.config, "torch_dtype", output_dtype)
model.save_pretrained(
save_directory=model_args.export_dir,