Former-commit-id: e0b0c4415aaf80e75f6dd4f3777a0616b0e60f84
This commit is contained in:
hiyouga
2024-01-24 16:19:18 +08:00
parent 8947a87b95
commit 1ace676170
3 changed files with 43 additions and 38 deletions

View File

@@ -56,7 +56,9 @@ def export_model(args: Optional[Dict[str, Any]] = None):
if not isinstance(model, PreTrainedModel):
raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
if hasattr(model.config, "torch_dtype"):
if getattr(model, "quantization_method", None):
model = model.to("cpu")
elif hasattr(model.config, "torch_dtype"):
model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
else:
model = model.to(torch.float16).to("cpu")