support function calling

Former-commit-id: 66533b3f65babf2429c92c0f8fafe4eff5e0ff63
This commit is contained in:
hiyouga
2024-01-18 09:54:23 +08:00
parent f7329b1a0e
commit a423274fd9
67 changed files with 1239 additions and 1079 deletions

View File

@@ -1,13 +1,13 @@
import torch
import tiktoken
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
from threading import Thread
from transformers import GenerationConfig, TextIteratorStreamer
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.extras.misc import get_logits_processor
from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor
from ..model import dispatch_model, load_model_and_tokenizer
from ..hparams import get_infer_args
@dataclass
@@ -139,11 +139,6 @@ class ChatModel:
batch_input: List[str],
**input_kwargs
) -> List[float]:
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=True)
max_length = input_kwargs.pop("max_length", None)
device = getattr(self.model.pretrained_model, "device", "cuda")
@@ -153,7 +148,7 @@ class ChatModel:
truncation=True,
max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
return_tensors="pt",
**kwargs
add_special_tokens=True
).to(device)
input_ids: torch.Tensor = inputs["input_ids"]