format style

Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
hiyouga
2024-01-20 20:15:56 +08:00
parent 1750218057
commit 66e0e651b9
73 changed files with 1492 additions and 2325 deletions


@@ -1,18 +1,18 @@
-import torch
 from dataclasses import dataclass
-from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread
+from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
+import torch
 from transformers import GenerationConfig, TextIteratorStreamer
 
-from ..data import get_template_and_fix_tokenizer, Role
+from ..data import Role, get_template_and_fix_tokenizer
 from ..extras.misc import get_logits_processor
-from ..model import dispatch_model, load_model_and_tokenizer
 from ..hparams import get_infer_args
+from ..model import dispatch_model, load_model_and_tokenizer
 
 
 @dataclass
 class Response:
     response_text: str
     response_length: int
     prompt_length: int
@@ -20,10 +20,9 @@ class Response:
 class ChatModel:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
-        self.can_generate = (finetuning_args.stage == "sft")
+        self.can_generate = finetuning_args.stage == "sft"
         self.model, self.tokenizer = load_model_and_tokenizer(
             model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )
@@ -37,7 +36,7 @@ class ChatModel:
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        **input_kwargs
+        **input_kwargs,
     ) -> Tuple[Dict[str, Any], int]:
         messages = []
         if history is not None:
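Context note (not part of the commit): _process_args flattens the (user, assistant) history tuples into a role-tagged message list before templating. A minimal sketch of that conversion; the dict field names are assumptions, since the relevant lines sit outside the hunks shown here:

# Hypothetical sketch of the history flattening in _process_args; the
# "role"/"content" field names are assumptions, not shown in this diff.
history = [("Hi", "Hello! How can I help?")]
messages = []
for user_msg, assistant_msg in history:
    messages.append({"role": "user", "content": user_msg})
    messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": "What's new?"})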
@@ -63,16 +62,18 @@
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
 
         generating_args = self.generating_args.to_dict()
-        generating_args.update(dict(
-            do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
-            temperature=temperature or generating_args["temperature"],
-            top_p=top_p or generating_args["top_p"],
-            top_k=top_k or generating_args["top_k"],
-            num_return_sequences=num_return_sequences or 1,
-            repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
-            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            pad_token_id=self.tokenizer.pad_token_id
-        ))
+        generating_args.update(
+            dict(
+                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+                temperature=temperature or generating_args["temperature"],
+                top_p=top_p or generating_args["top_p"],
+                top_k=top_k or generating_args["top_k"],
+                num_return_sequences=num_return_sequences or 1,
+                repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
+                eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+                pad_token_id=self.tokenizer.pad_token_id,
+            )
+        )
 
         if isinstance(num_return_sequences, int) and num_return_sequences > 1:
             generating_args["do_sample"] = True
@@ -88,7 +89,7 @@ class ChatModel:
         gen_kwargs = dict(
             inputs=input_ids,
             generation_config=GenerationConfig(**generating_args),
-            logits_processor=get_logits_processor()
+            logits_processor=get_logits_processor(),
         )
 
         return gen_kwargs, prompt_length
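Side note (not part of the commit): the `value or default` pattern in the hunk above means falsy overrides fall back to the configured defaults, so passing temperature=0.0 or top_k=0 keeps the default rather than overriding it. A minimal standalone sketch of that behavior, with hypothetical default values:

# Sketch of the override pattern used in _process_args; the defaults here
# are hypothetical, not taken from the repository.
defaults = {"temperature": 0.95, "top_k": 50}

def resolve(temperature=None, top_k=None):
    # `x or default` falls back whenever x is falsy, so 0 and 0.0 behave
    # the same as None and cannot override the defaults.
    return {
        "temperature": temperature or defaults["temperature"],
        "top_k": top_k or defaults["top_k"],
    }

print(resolve(temperature=0.7))  # {'temperature': 0.7, 'top_k': 50}
print(resolve(temperature=0.0))  # {'temperature': 0.95, 'top_k': 50}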
@@ -100,7 +101,7 @@ class ChatModel:
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        **input_kwargs
+        **input_kwargs,
     ) -> List[Response]:
         r"""
         Args: query, history, system, **input_kwargs
@@ -117,12 +118,14 @@
         for i in range(len(response)):
             eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
             response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
-            results.append(Response(
-                response_text=response[i],
-                response_length=response_length,
-                prompt_length=prompt_length,
-                finish_reason="stop" if len(eos_index) else "length"
-            ))
+            results.append(
+                Response(
+                    response_text=response[i],
+                    response_length=response_length,
+                    prompt_length=prompt_length,
+                    finish_reason="stop" if len(eos_index) else "length",
+                )
+            )
 
         return results
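For context (not in the diff): a hypothetical usage sketch of the chat API formatted above. The import path and args keys are assumptions based on the relative imports in the first hunk, and the model path is a placeholder:

# Hypothetical usage sketch; `llmtuner.chat`, the args keys, and the model
# path are assumptions, not confirmed by this diff.
from llmtuner.chat import ChatModel

chat_model = ChatModel({"model_name_or_path": "path/to/model", "template": "default"})
responses = chat_model.chat("Hello!", history=None, system=None)
for r in responses:
    print(r.finish_reason, r.response_length, r.response_text)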
@@ -133,7 +136,7 @@ class ChatModel:
         history: Optional[List[Tuple[str, str]]] = None,
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        **input_kwargs
+        **input_kwargs,
     ) -> Generator[str, None, None]:
         gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
@@ -145,11 +148,7 @@
         yield from streamer
 
     @torch.inference_mode()
-    def get_scores(
-        self,
-        batch_input: List[str],
-        **input_kwargs
-    ) -> List[float]:
+    def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
         max_length = input_kwargs.pop("max_length", None)
         device = getattr(self.model.pretrained_model, "device", "cuda")
@@ -159,7 +158,7 @@ class ChatModel:
             truncation=True,
             max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
             return_tensors="pt",
-            add_special_tokens=True
+            add_special_tokens=True,
         ).to(device)
         input_ids: torch.Tensor = inputs["input_ids"]
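Also for context: stream_chat pairs a background model.generate call with TextIteratorStreamer, so callers iterate decoded text chunks as they are produced. A minimal sketch of that producer/consumer pattern; starting generate() on a Thread is an assumption, since that line sits in context elided between the hunks:

# Sketch of the streaming pattern behind stream_chat; the Thread start is
# an assumption based on the `threading` import in the first hunk.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, gen_kwargs):
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs["streamer"] = streamer
    # generate() runs in the background and pushes decoded text into the
    # streamer, which the caller consumes as an iterator of string chunks.
    Thread(target=model.generate, kwargs=gen_kwargs, daemon=True).start()
    yield from streamer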