improve LoRA+ impl.
Former-commit-id: 332bad25455a70ad9204e7dd384bb086d789aa39
@@ -43,8 +43,10 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra
 def export_model(args: Optional[Dict[str, Any]] = None):
     model_args, data_args, finetuning_args, _ = get_infer_args(args)
 
+    model_args.device_map = {"": "cpu"}
+
     if model_args.export_dir is None:
-        raise ValueError("Please specify `export_dir`.")
+        raise ValueError("Please specify `export_dir` to save model.")
 
     if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
         raise ValueError("Please merge adapters before quantizing the model.")
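
Note on the added `device_map` line: the code removed in the next hunk relied on explicit `.to("cpu")` calls to get the weights off the GPU before saving; pinning `model_args.device_map = {"": "cpu"}` at load time places every weight on the CPU up front, so adapter merging and serialization never touch GPU memory. A minimal sketch of what that mapping means, assuming a transformers-style `from_pretrained` loader (the checkpoint name is illustrative, not from this commit):

    from transformers import AutoModelForCausalLM

    # The empty-string key addresses the root module, i.e. the whole model,
    # so every parameter is placed on the CPU at load time.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
        device_map={"": "cpu"},
    )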
@@ -58,13 +60,10 @@ def export_model(args: Optional[Dict[str, Any]] = None):
     if not isinstance(model, PreTrainedModel):
         raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
 
-    if getattr(model, "quantization_method", None):
-        model = model.to("cpu")
-    elif hasattr(model.config, "torch_dtype"):
-        model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
-    else:
-        model = model.to(torch.float16).to("cpu")
-        setattr(model.config, "torch_dtype", torch.float16)
+    if getattr(model, "quantization_method", None) is None:  # cannot convert dtype of a quantized model
+        output_dtype = getattr(model.config, "torch_dtype", torch.float16)
+        model = model.to(output_dtype)
+        setattr(model.config, "torch_dtype", output_dtype)
 
     model.save_pretrained(
         save_directory=model_args.export_dir,
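
For context on the replaced branch: the three-way dtype dispatch collapses into a single guard because `getattr(model.config, "torch_dtype", torch.float16)` already covers both the `hasattr` case and the float16 fallback, and the explicit `.to("cpu")` calls became redundant once the model is loaded with `device_map={"": "cpu"}`. A self-contained sketch of the consolidated logic (the helper name `_cast_for_export` is hypothetical):

    import torch

    def _cast_for_export(model):
        # Quantized weights are packed; casting them with .to(dtype) would
        # corrupt them, hence the `quantization_method` guard.
        if getattr(model, "quantization_method", None) is None:
            # Prefer the dtype recorded in the config, defaulting to float16,
            # and write the choice back so the exported config matches.
            output_dtype = getattr(model.config, "torch_dtype", torch.float16)
            model = model.to(output_dtype)
            setattr(model.config, "torch_dtype", output_dtype)
        return model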