format style

commit 66e0e651b9 (parent 1750218057)
Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
Author: hiyouga
Date: 2024-01-20 20:15:56 +08:00
73 changed files with 1492 additions and 2325 deletions


@@ -1,12 +1,12 @@
+import logging
 import os
 import time
-import logging
-import gradio as gr
 from threading import Thread
-from gradio.components import Component # cannot use TYPE_CHECKING here
 from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
 
+import gradio as gr
 import transformers
+from gradio.components import Component # cannot use TYPE_CHECKING here
 from transformers.trainer import TRAINING_ARGS_NAME
 
 from ..extras.callbacks import LogCallback
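
The import reshuffle in this hunk follows the usual isort/ruff grouping: standard-library modules first, third-party packages next, project-local imports last, each group sorted alphabetically. A small, self-contained sketch of that grouping rule (the first-party package name is an assumption for illustration, not taken from the commit):

import sys

FIRST_PARTY = {"llmtuner"}  # assumed project package name, illustration only

def import_group(module: str) -> str:
    """Classify a module the way an isort-style sorter would."""
    root = module.split(".")[0]
    if module.startswith(".") or root in FIRST_PARTY:
        return "first-party"
    if root in sys.stdlib_module_names:  # available on Python 3.10+
        return "standard library"
    return "third-party"

for name in ("logging", "os", "time", "threading", "typing", "gradio", "transformers"):
    print(f"{name:12s} -> {import_group(name)}")
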
@@ -18,12 +18,12 @@ from .common import get_module, get_save_dir, load_config
 from .locales import ALERTS
 from .utils import gen_cmd, get_eval_results, update_process_bar
 
 if TYPE_CHECKING:
     from .manager import Manager
 
 
 class Runner:
     def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
         self.manager = manager
         self.demo_mode = demo_mode
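
The context lines here show the TYPE_CHECKING pattern: the Manager import is seen only by type checkers, so it never runs at import time and cannot create a runtime circular import, while the gradio Component import above must stay a real import because, per its inline comment, it is needed at runtime. A generic, self-contained sketch of the pattern (hypothetical module path, not the project's code):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never executed at runtime, which is how a
    # circular import between cooperating modules is avoided.
    from some_app.manager import Manager  # hypothetical module path

class Runner:
    def __init__(self, manager: "Manager") -> None:
        # The quoted annotation is resolved lazily, so no runtime import is needed.
        self.manager = manager
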
@@ -90,9 +90,12 @@ class Runner:
         user_config = load_config()
 
         if get("top.adapter_path"):
-            adapter_name_or_path = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                for adapter in get("top.adapter_path")])
+            adapter_name_or_path = ",".join(
+                [
+                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+                    for adapter in get("top.adapter_path")
+                ]
+            )
         else:
             adapter_name_or_path = None
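
The reformatted comprehension builds exactly the same value as before: one comma-separated string of adapter checkpoint directories. For illustration only, assuming get_save_dir simply joins its arguments under a fixed save root (the real helper comes from .common and may differ) and using made-up inputs:

import os

SAVE_ROOT = "saves"  # assumed save root, for illustration

def get_save_dir(*paths: str) -> str:
    # Hypothetical stand-in for the real helper imported from .common.
    return os.path.join(SAVE_ROOT, *paths)

model_name = "LLaMA2-7B"                       # illustrative values
finetuning_type = "lora"
adapter_path = ["checkpoint-100", "checkpoint-200"]

adapter_name_or_path = ",".join(
    [get_save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path]
)
print(adapter_name_or_path)
# saves/LLaMA2-7B/lora/checkpoint-100,saves/LLaMA2-7B/lora/checkpoint-200 (POSIX paths)

The same rewrite is applied again to the eval-argument parser in a later hunk.
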
@@ -131,12 +134,12 @@ class Runner:
             create_new_adapter=get("train.create_new_adapter"),
             output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
-            bf16=(get("train.compute_type") == "bf16")
+            bf16=(get("train.compute_type") == "bf16"),
         )
         args["disable_tqdm"] = True
 
         if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
-            args["create_new_adapter"] = (args["quantization_bit"] is None)
+            args["create_new_adapter"] = args["quantization_bit"] is None
 
         if args["stage"] == "ppo":
             args["reward_model"] = get_save_dir(
@@ -161,9 +164,12 @@ class Runner:
         user_config = load_config()
 
         if get("top.adapter_path"):
-            adapter_name_or_path = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                for adapter in get("top.adapter_path")])
+            adapter_name_or_path = ",".join(
+                [
+                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+                    for adapter in get("top.adapter_path")
+                ]
+            )
         else:
             adapter_name_or_path = None
@@ -187,7 +193,7 @@ class Runner:
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
-            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir"))
+            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
         )
 
         if get("eval.predict"):
@@ -197,7 +203,9 @@ class Runner:
         return args
 
-    def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+    def _preview(
+        self, data: Dict[Component, Any], do_train: bool
+    ) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
         error = self._initialize(data, do_train, from_preview=True)
         if error:
             gr.Warning(error)
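
The only change to _preview is wrapping its signature across lines to satisfy the line-length limit. Its return annotation, Generator[Tuple[str, Dict[str, Any]], None, None], says the method yields (message, component-update) pairs and returns nothing; a minimal, self-contained sketch of that generator pattern (not the project's actual implementation):

from typing import Any, Dict, Generator, Tuple

def preview_sketch(
    error: str, command: str
) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
    # Yield the validation error when there is one, otherwise the rendered
    # command, paired with a dict describing a component update.
    if error:
        yield error, {"visible": False}
    else:
        yield command, {"visible": False}

for message, update in preview_sketch("", "python train.py --stage sft"):
    print(message, update)
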
@@ -235,9 +243,11 @@ class Runner:
         get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
         self.running = True
         lang = get("top.lang")
-        output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
-            "{}.output_dir".format("train" if self.do_train else "eval")
-        ))
+        output_dir = get_save_dir(
+            get("top.model_name"),
+            get("top.finetuning_type"),
+            get("{}.output_dir".format("train" if self.do_train else "eval")),
+        )
 
         while self.thread.is_alive():
             time.sleep(2)
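
The final hunk splits the nested get_save_dir call into one argument per line, again with no behavioral change. The surrounding context is the monitor loop that polls the background training thread every two seconds; a stripped-down sketch of that polling pattern (illustrative job, not the runner's real workload):

import time
from threading import Thread

def run_job() -> None:
    time.sleep(6)  # stand-in for the actual training run

thread = Thread(target=run_job, daemon=True)
thread.start()

# Same idea as the loop above: poll every two seconds until the job ends,
# then read its results.
while thread.is_alive():
    time.sleep(2)
    print("still running...")
print("finished")
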