modify code structure
Former-commit-id: 6369f9b1751e6f9bb709ba76a85f69cbe0823e5d
This commit is contained in:
@@ -1,22 +1,21 @@
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from llmtuner.chat.stream_chat import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.hparams import GeneratingArguments
|
||||
from llmtuner.tuner import get_infer_args
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
|
||||
|
||||
class WebChatModel(ChatModel):
|
||||
|
||||
def __init__(self, *args):
|
||||
def __init__(self, args: Optional[Dict[str, Any]]) -> None:
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.generating_args = GeneratingArguments()
|
||||
if len(args) != 0:
|
||||
super().__init__(*args)
|
||||
if args is not None:
|
||||
super().__init__(args)
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
@@ -57,7 +56,7 @@ class WebChatModel(ChatModel):
|
||||
template=template,
|
||||
source_prefix=source_prefix
|
||||
)
|
||||
super().__init__(*get_infer_args(args))
|
||||
super().__init__(args)
|
||||
|
||||
yield ALERTS["info_loaded"][lang]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user