Former-commit-id: 09762d9849655f5e6c71b9472d55b42489dd944b
This commit is contained in:
hiyouga
2023-08-01 18:13:03 +08:00
parent cb4d1d5ebb
commit 250fecfcd4
2 changed files with 24 additions and 10 deletions

View File

@@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import TextIteratorStreamer
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.misc import dispatch_model, get_logits_processor
from llmtuner.extras.template import get_template
from llmtuner.tuner import load_model_and_tokenizer
@@ -21,15 +21,7 @@ class ChatModel:
generating_args: "GeneratingArguments"
) -> None:
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
if torch.cuda.device_count() > 1:
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model))
self.model = dispatch_model(self.model, device_map)
else:
self.model = self.model.cuda()
self.model = dispatch_model(self.model)
self.template = get_template(data_args.template)
self.source_prefix = data_args.source_prefix
self.generating_args = generating_args