support function calling

Former-commit-id: 66533b3f65babf2429c92c0f8fafe4eff5e0ff63
This commit is contained in:
hiyouga
2024-01-18 09:54:23 +08:00
parent f7329b1a0e
commit a423274fd9
67 changed files with 1239 additions and 1079 deletions

View File

@@ -3,7 +3,6 @@
import os
import json
import torch
import tiktoken
import numpy as np
from tqdm import tqdm, trange
from typing import Any, Dict, List, Optional
@@ -11,10 +10,11 @@ from typing import Any, Dict, List, Optional
from datasets import load_dataset
from transformers.utils import cached_file
from llmtuner.data.template import get_template_and_fix_tokenizer
from llmtuner.eval.template import get_eval_template
from llmtuner.extras.constants import CHOICES, SUBJECTS
from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer
from ..data import get_template_and_fix_tokenizer
from .template import get_eval_template
from ..extras.constants import CHOICES, SUBJECTS
from ..hparams import get_eval_args
from ..model import dispatch_model, load_model_and_tokenizer
class Evaluator:
@@ -26,15 +26,9 @@ class Evaluator:
self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = self._encode_choices()
def _encode_choices(self) -> List[int]:
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
kwargs = dict(allowed_special="all")
else:
kwargs = dict(add_special_tokens=False)
return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
self.choice_inputs = [self.tokenizer.encode(
self.eval_template.prefix + ch, add_special_tokens=False
)[-1] for ch in CHOICES]
@torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]: