modify code structure

Former-commit-id: 6369f9b1751e6f9bb709ba76a85f69cbe0823e5d
This commit is contained in:
hiyouga
2023-08-02 23:17:36 +08:00
parent 8bd1da7144
commit 28a51b622b
25 changed files with 188 additions and 145 deletions

View File

@@ -1,30 +1,21 @@
import torch
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import TextIteratorStreamer
from llmtuner.extras.misc import dispatch_model, get_logits_processor
from llmtuner.extras.template import get_template
from llmtuner.tuner import load_model_and_tokenizer
if TYPE_CHECKING:
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
class ChatModel:
def __init__(
self,
model_args: "ModelArguments",
data_args: "DataArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments"
) -> None:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.model = dispatch_model(self.model)
self.template = get_template(data_args.template)
self.source_prefix = data_args.source_prefix
self.generating_args = generating_args
def process_args(
self,