[data] fix ollama template (#6902)

* fix ollama template

* add meta info

* use half precision

Former-commit-id: 1304bbea69d8c8ca57140017515dee7ae2ee6536
This commit is contained in:
hoshi-hiyouga
2025-02-11 22:43:09 +08:00
committed by GitHub
parent 88eafd865b
commit 86063e27ea
4 changed files with 8 additions and 4 deletions

View File

@@ -22,6 +22,7 @@ from transformers import PreTrainedModel
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
+from ..extras.misc import infer_optim_dtype
from ..extras.packages import is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
@@ -117,7 +118,9 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
setattr(model.config, "torch_dtype", torch.float16)
else:
if model_args.infer_dtype == "auto":
-output_dtype = getattr(model.config, "torch_dtype", torch.float16)
+output_dtype = getattr(model.config, "torch_dtype", torch.float32)
+if output_dtype == torch.float32: # if infer_dtype is auto, try using half precision first
+output_dtype = infer_optim_dtype(torch.bfloat16)
else:
output_dtype = getattr(torch, model_args.infer_dtype)