update web UI, support rm predict #210
Former-commit-id: 92cc6b655dc91b94d5bf9d8618c3b57d5cf94333
This commit is contained in:
@@ -3,7 +3,7 @@ import os
|
||||
import threading
|
||||
import time
|
||||
import transformers
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE
|
||||
@@ -25,7 +25,9 @@ class Runner:
|
||||
self.aborted = True
|
||||
self.running = False
|
||||
|
||||
def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]:
|
||||
def initialize(
|
||||
self, lang: str, model_name: str, dataset: list
|
||||
) -> Tuple[str, str, LoggerHandler, LogCallback]:
|
||||
if self.running:
|
||||
return None, ALERTS["err_conflict"][lang], None, None
|
||||
|
||||
@@ -50,7 +52,9 @@ class Runner:
|
||||
|
||||
return model_name_or_path, "", logger_handler, trainer_callback
|
||||
|
||||
def finalize(self, lang: str, finish_info: Optional[str] = None) -> str:
|
||||
def finalize(
|
||||
self, lang: str, finish_info: Optional[str] = None
|
||||
) -> str:
|
||||
self.running = False
|
||||
torch_gc()
|
||||
if self.aborted:
|
||||
@@ -87,7 +91,7 @@ class Runner:
|
||||
lora_dropout: float,
|
||||
lora_target: str,
|
||||
output_dir: str
|
||||
):
|
||||
) -> Generator[str, None, None]:
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error
|
||||
@@ -174,7 +178,7 @@ class Runner:
|
||||
max_samples: str,
|
||||
batch_size: int,
|
||||
predict: bool
|
||||
):
|
||||
) -> Generator[str, None, None]:
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error
|
||||
|
||||
Reference in New Issue
Block a user