support function calling
Former-commit-id: 66533b3f65babf2429c92c0f8fafe4eff5e0ff63
@@ -1,13 +1,13 @@
 import torch
-import tiktoken
 from dataclasses import dataclass
 from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread
 from transformers import GenerationConfig, TextIteratorStreamer
 
-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.extras.misc import get_logits_processor
-from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
+from ..data import get_template_and_fix_tokenizer
+from ..extras.misc import get_logits_processor
+from ..model import dispatch_model, load_model_and_tokenizer
+from ..hparams import get_infer_args
 
 
 @dataclass
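The hunk above swaps absolute llmtuner.* imports for relative ones and moves get_infer_args out of the model subpackage into hparams. A minimal usage sketch of the relocated helpers, assuming the post-refactor package layout (the argument values and the four-tuple return are assumptions, not shown in this diff):

    # Hypothetical usage sketch; assumes get_infer_args is now exported from
    # llmtuner.hparams and the loaders from llmtuner.model.
    from llmtuner.hparams import get_infer_args
    from llmtuner.model import dispatch_model, load_model_and_tokenizer

    # Parse inference arguments from a plain dict (values are illustrative).
    model_args, data_args, finetuning_args, generating_args = get_infer_args(
        dict(model_name_or_path="Qwen/Qwen-7B-Chat", template="qwen")
    )

    # Load model and tokenizer, then dispatch the model onto available devices.
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
    model = dispatch_model(model)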
@@ -139,11 +139,6 @@ class ChatModel:
         batch_input: List[str],
         **input_kwargs
     ) -> List[float]:
-        if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
-            kwargs = dict(allowed_special="all")
-        else:
-            kwargs = dict(add_special_tokens=True)
-
         max_length = input_kwargs.pop("max_length", None)
         device = getattr(self.model.pretrained_model, "device", "cuda")
 
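Removing this branch drops the tiktoken special case for Qwen's tokenizer: the scoring method now always tokenizes with add_special_tokens=True (next hunk) and only pops max_length from its keyword arguments. A hedged caller sketch; the method name get_scores is assumed, since the hunk shows only the signature tail:

    # Hypothetical caller sketch; "get_scores" and an existing ChatModel
    # instance named chat_model are assumptions.
    scores = chat_model.get_scores(
        ["How are you?", "Write a poem."],  # batch_input: List[str]
        max_length=2048,                    # consumed via input_kwargs.pop("max_length", None)
    )
    print(scores)                           # List[float], one score per input sequence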
@@ -153,7 +148,7 @@ class ChatModel:
             truncation=True,
             max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
             return_tensors="pt",
-            **kwargs
+            add_special_tokens=True
         ).to(device)
 
         input_ids: torch.Tensor = inputs["input_ids"]
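The diff view ends at input_ids. For context, a hypothetical continuation of the scoring method (not part of this commit), assuming self.model is a trl-style value-head wrapper, as suggested by self.model.pretrained_model above, whose forward pass returns per-token values of shape (batch_size, seq_len):

    # Hypothetical continuation, not in this diff: reduce per-token values
    # to one scalar score per sequence.
    _, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True)

    scores = []
    for i in range(input_ids.size(0)):
        # Locate the last non-padding token of sequence i.
        end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
        end_index = end_indexes[-1].item() if len(end_indexes) else 0
        scores.append(values[i, end_index].nan_to_num().item())

    return scores  # List[float], as the signature promises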