update web UI, support rm predict #210

Former-commit-id: 92cc6b655dc91b94d5bf9d8618c3b57d5cf94333
This commit is contained in:
hiyouga
2023-07-21 13:27:27 +08:00
parent c4e9694c6e
commit 0f7cdac207
13 changed files with 192 additions and 27 deletions

View File

@@ -3,7 +3,7 @@ import os
import threading
import time
import transformers
from typing import List, Optional, Tuple
from typing import Generator, List, Optional, Tuple
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE
@@ -25,7 +25,9 @@ class Runner:
self.aborted = True
self.running = False
def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]:
def initialize(
self, lang: str, model_name: str, dataset: list
) -> Tuple[str, str, LoggerHandler, LogCallback]:
if self.running:
return None, ALERTS["err_conflict"][lang], None, None
@@ -50,7 +52,9 @@ class Runner:
return model_name_or_path, "", logger_handler, trainer_callback
def finalize(self, lang: str, finish_info: Optional[str] = None) -> str:
def finalize(
self, lang: str, finish_info: Optional[str] = None
) -> str:
self.running = False
torch_gc()
if self.aborted:
@@ -87,7 +91,7 @@ class Runner:
lora_dropout: float,
lora_target: str,
output_dir: str
):
) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error
@@ -174,7 +178,7 @@ class Runner:
max_samples: str,
batch_size: int,
predict: bool
):
) -> Generator[str, None, None]:
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
if error:
yield error