[data] fix ollama template (#6902)
* fix ollama template * add meta info * use half precision Former-commit-id: 1304bbea69d8c8ca57140017515dee7ae2ee6536
This commit is contained in:
@@ -321,10 +321,11 @@ class Template:
|
||||
|
||||
TODO: support function calling.
|
||||
"""
|
||||
modelfile = f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
|
||||
modelfile = "# ollama modelfile auto-generated by llamafactory\n\n"
|
||||
modelfile += f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
|
||||
|
||||
if self.default_system:
|
||||
modelfile += f'SYSTEM system "{self.default_system}"\n\n'
|
||||
modelfile += f'SYSTEM """{self.default_system}"""\n\n'
|
||||
|
||||
for stop_token_id in self.get_stop_token_ids(tokenizer):
|
||||
modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
|
||||
|
||||
@@ -22,6 +22,7 @@ from transformers import PreTrainedModel
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.misc import infer_optim_dtype
|
||||
from ..extras.packages import is_ray_available
|
||||
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
@@ -117,7 +118,9 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
||||
setattr(model.config, "torch_dtype", torch.float16)
|
||||
else:
|
||||
if model_args.infer_dtype == "auto":
|
||||
output_dtype = getattr(model.config, "torch_dtype", torch.float16)
|
||||
output_dtype = getattr(model.config, "torch_dtype", torch.float32)
|
||||
if output_dtype == torch.float32: # if infer_dtype is auto, try using half precision first
|
||||
output_dtype = infer_optim_dtype(torch.bfloat16)
|
||||
else:
|
||||
output_dtype = getattr(torch, model_args.infer_dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user