Former-commit-id: 80a346e29beb49e8935b786e2af1059fdc4954b2
This commit is contained in:
hiyouga
2023-07-25 17:04:02 +08:00
parent c145bbef3c
commit ac587438f8
4 changed files with 22 additions and 3 deletions

View File

@@ -19,6 +19,14 @@ class ChatModel:
generating_args: GeneratingArguments
) -> None:
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model, infer_auto_device_map
device_map = infer_auto_device_map(self.model)
self.model = dispatch_model(self.model, device_map)
else:
self.model = self.model.cuda()
self.template = get_template(data_args.prompt_template)
self.source_prefix = data_args.source_prefix or ""
self.generating_args = generating_args
@@ -32,6 +40,7 @@ class ChatModel:
inputs = inputs.to(self.model.device)
prompt_length = len(inputs["input_ids"][0])
do_sample = input_kwargs.pop("do_sample", None)
temperature = input_kwargs.pop("temperature", None)
top_p = input_kwargs.pop("top_p", None)
top_k = input_kwargs.pop("top_k", None)
@@ -42,6 +51,7 @@ class ChatModel:
gen_kwargs = self.generating_args.to_dict()
gen_kwargs.update(dict(
input_ids=inputs["input_ids"],
do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
temperature=temperature or gen_kwargs["temperature"],
top_p=top_p or gen_kwargs["top_p"],
top_k=top_k or gen_kwargs["top_k"],