Former-commit-id: 819cc1353599e5fa45658bc56dd0dbe4b258b197
This commit is contained in:
@@ -1,42 +1,50 @@
|
||||
import torch
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer
|
||||
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.template import get_template
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
from llmtuner.tuner import load_model_and_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
class ChatModel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
generating_args: GeneratingArguments
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments"
|
||||
) -> None:
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
from accelerate import dispatch_model, infer_auto_device_map
|
||||
device_map = infer_auto_device_map(self.model)
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
||||
device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model))
|
||||
self.model = dispatch_model(self.model, device_map)
|
||||
else:
|
||||
self.model = self.model.cuda()
|
||||
|
||||
self.template = get_template(data_args.prompt_template)
|
||||
self.source_prefix = data_args.source_prefix or ""
|
||||
self.template = get_template(data_args.template)
|
||||
self.source_prefix = data_args.source_prefix
|
||||
self.generating_args = generating_args
|
||||
|
||||
def process_args(
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
prefix = prefix or self.source_prefix
|
||||
|
||||
inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt")
|
||||
prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token)
|
||||
inputs = self.tokenizer([prompt], return_tensors="pt")
|
||||
inputs = inputs.to(self.model.device)
|
||||
prompt_length = len(inputs["input_ids"][0])
|
||||
|
||||
@@ -71,7 +79,11 @@ class ChatModel:
|
||||
|
||||
@torch.inference_mode()
|
||||
def chat(
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[str, Tuple[int, int]]:
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
|
||||
generation_output = self.model.generate(**gen_kwargs)
|
||||
@@ -82,7 +94,11 @@ class ChatModel:
|
||||
|
||||
@torch.inference_mode()
|
||||
def stream_chat(
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
Reference in New Issue
Block a user