modify code structure
Former-commit-id: 6369f9b1751e6f9bb709ba76a85f69cbe0823e5d
This commit is contained in:
@@ -1,30 +1,21 @@
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer
|
||||
|
||||
from llmtuner.extras.misc import dispatch_model, get_logits_processor
|
||||
from llmtuner.extras.template import get_template
|
||||
from llmtuner.tuner import load_model_and_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
|
||||
|
||||
|
||||
class ChatModel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments"
|
||||
) -> None:
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
self.model = dispatch_model(self.model)
|
||||
self.template = get_template(data_args.template)
|
||||
self.source_prefix = data_args.source_prefix
|
||||
self.generating_args = generating_args
|
||||
|
||||
def process_args(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user