[misc] upgrade format to py39 (#7256)

hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions
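
The rewrite rule this commit applies uniformly across the files below: since Python 3.9 the builtin collections are subscriptable (PEP 585), so the typing aliases Dict, List, Tuple, and Set become dict, list, tuple, and set, and Generator is imported from collections.abc, of which the typing version is now a deprecated alias. Optional and Union stay in place, because the X | None spelling (PEP 604) requires Python 3.10. A minimal before/after sketch; the stream signature here is illustrative rather than copied from any one file:

# Before: py3.8-era typing aliases.
from typing import Dict, Generator, List, Optional, Tuple

def stream(messages: List[Dict[str, str]]) -> Generator[Tuple[str, str], None, None]:
    yield "assistant", messages[-1]["content"]

# After: py3.9 builtin generics (PEP 585) plus collections.abc.
from collections.abc import Generator
from typing import Optional

def stream(messages: list[dict[str, str]]) -> Generator[tuple[str, str], None, None]:
    yield "assistant", messages[-1]["content"]

engine: Optional[str] = None  # Optional/Union remain: `str | None` needs py3.10 (PEP 604)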

View File

@@ -14,7 +14,8 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Optional
from transformers.utils import is_torch_npu_available
@@ -37,15 +38,12 @@ if is_gradio_available():
def _escape_html(text: str) -> str:
r"""
Escapes HTML characters.
"""
r"""Escape HTML characters."""
return text.replace("<", "&lt;").replace(">", "&gt;")
def _format_response(text: str, lang: str, escape_html: bool, thought_words: Tuple[str, str]) -> str:
r"""
Post-processes the response text.
def _format_response(text: str, lang: str, escape_html: bool, thought_words: tuple[str, str]) -> str:
r"""Post-process the response text.
Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
"""
@@ -74,7 +72,7 @@ class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
self.demo_mode = demo_mode
self.engine: Optional["BaseEngine"] = None
self.engine: Optional[BaseEngine] = None
if not lazy_init: # read arguments from command line
super().__init__()
@@ -160,14 +158,13 @@ class WebChatModel(ChatModel):
@staticmethod
def append(
chatbot: List[Dict[str, str]],
messages: List[Dict[str, str]],
chatbot: list[dict[str, str]],
messages: list[dict[str, str]],
role: str,
query: str,
escape_html: bool,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
r"""
Adds the user input to chatbot.
) -> tuple[list[dict[str, str]], list[dict[str, str]], str]:
r"""Add the user input to chatbot.
Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
Output: infer.chatbot, infer.messages, infer.query
@@ -180,8 +177,8 @@ class WebChatModel(ChatModel):
def stream(
self,
chatbot: List[Dict[str, str]],
messages: List[Dict[str, str]],
chatbot: list[dict[str, str]],
messages: list[dict[str, str]],
lang: str,
system: str,
tools: str,
@@ -193,9 +190,8 @@ class WebChatModel(ChatModel):
temperature: float,
skip_special_tokens: bool,
escape_html: bool,
) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
r"""
Generates output text in stream.
) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
r"""Generate output text in stream.
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
Output: infer.chatbot, infer.messages
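
The other half of this diff is docstring normalization: multi-line, third-person summaries ("Escapes HTML characters.") collapse into one-line imperative summaries, matching pydocstyle's D200/D401 conventions (presumably what the project's linter enforces). Taking _escape_html from the hunk above:

# Before: three physical lines, third-person verb.
def _escape_html(text: str) -> str:
    r"""
    Escapes HTML characters.
    """
    return text.replace("<", "&lt;").replace(">", "&gt;")

# After: one-line imperative summary.
def _escape_html(text: str) -> str:
    r"""Escape HTML characters."""
    return text.replace("<", "&lt;").replace(">", "&gt;")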

View File

@@ -17,7 +17,7 @@ import os
import signal
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union
from psutil import Process
from yaml import safe_dump, safe_load
@@ -44,9 +44,7 @@ USER_CONFIG = "user_config.yaml"
def abort_process(pid: int) -> None:
r"""
Aborts the processes recursively in a bottom-up way.
"""
r"""Abort the processes recursively in a bottom-up way."""
try:
children = Process(pid).children()
if children:
@@ -59,9 +57,7 @@ def abort_process(pid: int) -> None:
def get_save_dir(*paths: str) -> os.PathLike:
r"""
Gets the path to saved model checkpoints.
"""
r"""Get the path to saved model checkpoints."""
if os.path.sep in paths[-1]:
logger.warning_rank0("Found complex path, some features may be not available.")
return paths[-1]
@@ -71,16 +67,12 @@ def get_save_dir(*paths: str) -> os.PathLike:
def _get_config_path() -> os.PathLike:
r"""
Gets the path to user config.
"""
r"""Get the path to user config."""
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
r"""
Loads user config if exists.
"""
def load_config() -> dict[str, Union[str, dict[str, Any]]]:
r"""Load user config if exists."""
try:
with open(_get_config_path(), encoding="utf-8") as f:
return safe_load(f)
@@ -89,9 +81,7 @@ def load_config() -> Dict[str, Union[str, Dict[str, Any]]]:
def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
r"""
Saves user config.
"""
r"""Save user config."""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
user_config = load_config()
user_config["lang"] = lang or user_config["lang"]
@@ -106,11 +96,9 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
def get_model_path(model_name: str) -> str:
r"""
Gets the model path according to the model name.
"""
r"""Get the model path according to the model name."""
user_config = load_config()
path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
path_dict: dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "")
if (
use_modelscope()
@@ -130,30 +118,22 @@ def get_model_path(model_name: str) -> str:
def get_template(model_name: str) -> str:
r"""
Gets the template name if the model is a chat/distill/instruct model.
"""
r"""Get the template name if the model is a chat/distill/instruct model."""
return DEFAULT_TEMPLATE.get(model_name, "default")
def get_time() -> str:
r"""
Gets current date and time.
"""
r"""Get current date and time."""
return datetime.now().strftime(r"%Y-%m-%d-%H-%M-%S")
def is_multimodal(model_name: str) -> bool:
r"""
Judges if the model is a vision language model.
"""
r"""Judge if the model is a vision language model."""
return model_name in MULTIMODAL_SUPPORTED_MODELS
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
r"""
Loads dataset_info.json.
"""
def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
r"""Load dataset_info.json."""
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.")
return {}
@@ -166,10 +146,8 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
return {}
def load_args(config_path: str) -> Optional[Dict[str, Any]]:
r"""
Loads the training configuration from config path.
"""
def load_args(config_path: str) -> Optional[dict[str, Any]]:
r"""Load the training configuration from config path."""
try:
with open(config_path, encoding="utf-8") as f:
return safe_load(f)
@@ -177,26 +155,20 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
return None
def save_args(config_path: str, config_dict: Dict[str, Any]) -> None:
r"""
Saves the training configuration to config path.
"""
def save_args(config_path: str, config_dict: dict[str, Any]) -> None:
r"""Save the training configuration to config path."""
with open(config_path, "w", encoding="utf-8") as f:
safe_dump(config_dict, f)
def _clean_cmd(args: Dict[str, Any]) -> Dict[str, Any]:
r"""
Removes args with NoneType or False or empty string value.
"""
def _clean_cmd(args: dict[str, Any]) -> dict[str, Any]:
r"""Remove args with NoneType or False or empty string value."""
no_skip_keys = ["packing"]
return {k: v for k, v in args.items() if (k in no_skip_keys) or (v is not None and v is not False and v != "")}
def gen_cmd(args: Dict[str, Any]) -> str:
r"""
Generates CLI commands for previewing.
"""
def gen_cmd(args: dict[str, Any]) -> str:
r"""Generate CLI commands for previewing."""
cmd_lines = ["llamafactory-cli train "]
for k, v in _clean_cmd(args).items():
if isinstance(v, dict):
@@ -215,10 +187,8 @@ def gen_cmd(args: Dict[str, Any]) -> str:
return cmd_text
def save_cmd(args: Dict[str, Any]) -> str:
r"""
Saves CLI commands to launch training.
"""
def save_cmd(args: dict[str, Any]) -> str:
r"""Save CLI commands to launch training."""
output_dir = args["output_dir"]
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
@@ -228,9 +198,7 @@ def save_cmd(args: Dict[str, Any]) -> str:
def load_eval_results(path: os.PathLike) -> str:
r"""
Gets scores after evaluation.
"""
r"""Get scores after evaluation."""
with open(path, encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
@@ -238,9 +206,7 @@ def load_eval_results(path: os.PathLike) -> str:
def create_ds_config() -> None:
r"""
Creates deepspeed config in the current directory.
"""
r"""Create deepspeed config in the current directory."""
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
ds_config = {
"train_batch_size": "auto",

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import json
from typing import TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING
from ...data import Role
from ...extras.packages import is_gradio_available
@@ -31,9 +31,7 @@ if TYPE_CHECKING:
def check_json_schema(text: str, lang: str) -> None:
r"""
Checks if the json schema is valid.
"""
r"""Check if the json schema is valid."""
try:
tools = json.loads(text)
if tools:
@@ -49,7 +47,7 @@ def check_json_schema(text: str, lang: str) -> None:
def create_chat_box(
engine: "Engine", visible: bool = False
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
) -> tuple["Component", "Component", dict[str, "Component"]]:
lang = engine.manager.get_elem_by_id("top.lang")
with gr.Column(visible=visible) as chat_box:
chatbot = gr.Chatbot(type="messages", show_copy_button=True)

View File

@@ -14,7 +14,7 @@
import json
import os
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from typing import TYPE_CHECKING, Any
from ...extras.constants import DATA_CONFIG
from ...extras.packages import is_gradio_available
@@ -40,9 +40,7 @@ def next_page(page_index: int, total_num: int) -> int:
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
r"""
Checks if the dataset is a local dataset.
"""
r"""Check if the dataset is a local dataset."""
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
@@ -59,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
return gr.Button(interactive=False)
def _load_data_file(file_path: str) -> List[Any]:
def _load_data_file(file_path: str) -> list[Any]:
with open(file_path, encoding="utf-8") as f:
if file_path.endswith(".json"):
return json.load(f)
@@ -69,10 +67,8 @@ def _load_data_file(file_path: str) -> List[Any]:
return list(f)
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
r"""
Gets the preview samples from the dataset.
"""
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> tuple[int, list, "gr.Column"]:
r"""Get the preview samples from the dataset."""
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
@@ -87,7 +83,7 @@ def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int,
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.Column(visible=True)
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> dict[str, "Component"]:
data_preview_btn = gr.Button(interactive=False, scale=1)
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row():

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
from ..engine import Engine
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_eval_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()

View File

@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Generator, List, Union
from collections.abc import Generator
from typing import TYPE_CHECKING, Union
from ...extras.constants import PEFT_METHODS
from ...extras.misc import torch_gc
@@ -35,7 +36,7 @@ if TYPE_CHECKING:
GPTQ_BITS = ["8", "4", "3", "2"]
def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
return gr.Dropdown(value="none", interactive=False)
else:
@@ -47,7 +48,7 @@ def save_model(
model_name: str,
model_path: str,
finetuning_type: str,
checkpoint_path: Union[str, List[str]],
checkpoint_path: Union[str, list[str]],
template: str,
export_size: int,
export_quantization_bit: str,
@@ -106,7 +107,7 @@ def save_model(
yield ALERTS["info_exported"][lang]
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_export_tab(engine: "Engine") -> dict[str, "Component"]:
with gr.Row():
export_size = gr.Slider(minimum=1, maximum=100, value=5, step=1)
export_quantization_bit = gr.Dropdown(choices=["none"] + GPTQ_BITS, value="none")

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from ...extras.packages import is_gradio_available
from ..common import is_multimodal
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
from ..engine import Engine
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_infer_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from ...data import TEMPLATES
from ...extras.constants import METHODS, SUPPORTED_MODELS
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
from gradio.components import Component
def create_top() -> Dict[str, "Component"]:
def create_top() -> dict[str, "Component"]:
with gr.Row():
lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from transformers.trainer_utils import SchedulerType
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
from ..engine import Engine
def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
input_elems = engine.manager.get_base_elems()
elem_dict = dict()
@@ -382,8 +382,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
resume_btn.change(engine.runner.monitor, outputs=output_elems, concurrency_limit=None)
lang = engine.manager.get_elem_by_id("top.lang")
model_name: "gr.Dropdown" = engine.manager.get_elem_by_id("top.model_name")
finetuning_type: "gr.Dropdown" = engine.manager.get_elem_by_id("top.finetuning_type")
model_name: gr.Dropdown = engine.manager.get_elem_by_id("top.model_name")
finetuning_type: gr.Dropdown = engine.manager.get_elem_by_id("top.finetuning_type")
arg_save_btn.click(engine.runner.save_args, input_elems, output_elems, concurrency_limit=None)
arg_load_btn.click(

View File

@@ -14,7 +14,7 @@
import json
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Optional
from transformers.trainer_utils import get_last_checkpoint
@@ -39,8 +39,7 @@ if is_gradio_available():
def can_quantize(finetuning_type: str) -> "gr.Dropdown":
r"""
Judges if the quantization is available in this finetuning type.
r"""Judge if the quantization is available in this finetuning type.
Inputs: top.finetuning_type
Outputs: top.quantization_bit
@@ -52,8 +51,7 @@ def can_quantize(finetuning_type: str) -> "gr.Dropdown":
def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
r"""
Gets the available quantization bits.
r"""Get the available quantization bits.
Inputs: top.quantization_method
Outputs: top.quantization_bit
@@ -68,9 +66,8 @@ def can_quantize_to(quantization_method: str) -> "gr.Dropdown":
return gr.Dropdown(choices=available_bits)
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple[List[str], bool]:
r"""
Modifys states after changing the training stage.
def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> tuple[list[str], bool]:
r"""Modify states after changing the training stage.
Inputs: train.training_stage
Outputs: train.dataset, train.packing
@@ -78,9 +75,8 @@ def change_stage(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Tuple
return [], TRAINING_STAGES[training_stage] == "pt"
def get_model_info(model_name: str) -> Tuple[str, str]:
r"""
Gets the necessary information of this model.
def get_model_info(model_name: str) -> tuple[str, str]:
r"""Get the necessary information of this model.
Inputs: top.model_name
Outputs: top.model_path, top.template
@@ -88,9 +84,8 @@ def get_model_info(model_name: str) -> Tuple[str, str]:
return get_model_path(model_name), get_template(model_name)
def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr.Slider", Dict[str, Any]]:
r"""
Gets training infomation for monitor.
def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> tuple[str, "gr.Slider", dict[str, Any]]:
r"""Get training infomation for monitor.
If do_train is True:
Inputs: top.lang, train.output_path
@@ -110,7 +105,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tup
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
trainer_log: List[Dict[str, Any]] = []
trainer_log: list[dict[str, Any]] = []
with open(trainer_log_path, encoding="utf-8") as f:
for line in f:
trainer_log.append(json.loads(line))
@@ -143,8 +138,7 @@ def get_trainer_info(lang: str, output_path: os.PathLike, do_train: bool) -> Tup
def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
r"""
Lists all available checkpoints.
r"""List all available checkpoints.
Inputs: top.model_name, top.finetuning_type
Outputs: top.checkpoint_path
@@ -166,8 +160,7 @@ def list_checkpoints(model_name: str, finetuning_type: str) -> "gr.Dropdown":
def list_config_paths(current_time: str) -> "gr.Dropdown":
r"""
Lists all the saved configuration files.
r"""List all the saved configuration files.
Inputs: train.current_time
Outputs: train.config_path
@@ -182,8 +175,7 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":
def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
r"""
Lists all available datasets in the dataset dir for the training stage.
r"""List all available datasets in the dataset dir for the training stage.
Inputs: *.dataset_dir, *.training_stage
Outputs: *.dataset
@@ -195,8 +187,7 @@ def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_S
def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
r"""
Lists all the directories that can resume from.
r"""List all the directories that can resume from.
Inputs: top.model_name, top.finetuning_type, train.current_time
Outputs: train.output_dir
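
One local annotation in this file changes inside live logic: trainer_log: list[dict[str, Any]] in get_trainer_info, which accumulates one JSON object per line of the trainer log. A self-contained sketch of that JSONL-reading pattern, with a hypothetical path argument:

import json
from typing import Any

def read_trainer_log(path: str) -> list[dict[str, Any]]:
    trainer_log: list[dict[str, Any]] = []
    with open(path, encoding="utf-8") as f:
        for line in f:  # the log is JSONL: one JSON object per line
            trainer_log.append(json.loads(line))
    return trainer_log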

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any
from .chatter import WebChatModel
from .common import create_ds_config, get_time, load_config
@@ -26,9 +26,7 @@ if TYPE_CHECKING:
class Engine:
r"""
A general engine to control the behaviors of Web UI.
"""
r"""A general engine to control the behaviors of Web UI."""
def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
self.demo_mode = demo_mode
@@ -39,11 +37,9 @@ class Engine:
if not demo_mode:
create_ds_config()
def _update_component(self, input_dict: Dict[str, Dict[str, Any]]) -> Dict["Component", "Component"]:
r"""
Updates gradio components according to the (elem_id, properties) mapping.
"""
output_dict: Dict["Component", "Component"] = {}
def _update_component(self, input_dict: dict[str, dict[str, Any]]) -> dict["Component", "Component"]:
r"""Update gradio components according to the (elem_id, properties) mapping."""
output_dict: dict[Component, Component] = {}
for elem_id, elem_attr in input_dict.items():
elem = self.manager.get_elem_by_id(elem_id)
output_dict[elem] = elem.__class__(**elem_attr)
@@ -51,9 +47,7 @@ class Engine:
return output_dict
def resume(self):
r"""
Gets the initial value of gradio components and restores training status if necessary.
"""
r"""Get the initial value of gradio components and restores training status if necessary."""
user_config = load_config() if not self.demo_mode else {} # do not use config in demo mode
lang = user_config.get("lang", None) or "en"
init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
@@ -79,9 +73,7 @@ class Engine:
yield self._update_component({"eval.resume_btn": {"value": True}})
def change_lang(self, lang: str):
r"""
Updates the displayed language of gradio components.
"""
r"""Update the displayed language of gradio components."""
return {
elem: elem.__class__(**LOCALES[elem_name][lang])
for elem_name, elem in self.manager.get_elem_iter()

View File

@@ -48,7 +48,7 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
engine.manager.add_elems("top", create_top())
lang: "gr.Dropdown" = engine.manager.get_elem_by_id("top.lang")
lang: gr.Dropdown = engine.manager.get_elem_by_id("top.lang")
with gr.Tab("Train"):
engine.manager.add_elems("train", create_train_tab(engine))

View File

@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Generator, List, Set, Tuple
from collections.abc import Generator
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -20,54 +21,41 @@ if TYPE_CHECKING:
class Manager:
r"""
A class to manage all the gradio components in Web UI.
"""
r"""A class to manage all the gradio components in Web UI."""
def __init__(self) -> None:
self._id_to_elem: Dict[str, "Component"] = {}
self._elem_to_id: Dict["Component", str] = {}
self._id_to_elem: dict[str, Component] = {}
self._elem_to_id: dict[Component, str] = {}
def add_elems(self, tab_name: str, elem_dict: Dict[str, "Component"]) -> None:
r"""
Adds elements to manager.
"""
def add_elems(self, tab_name: str, elem_dict: dict[str, "Component"]) -> None:
r"""Add elements to manager."""
for elem_name, elem in elem_dict.items():
elem_id = f"{tab_name}.{elem_name}"
self._id_to_elem[elem_id] = elem
self._elem_to_id[elem] = elem_id
def get_elem_list(self) -> List["Component"]:
r"""
Returns the list of all elements.
"""
def get_elem_list(self) -> list["Component"]:
r"""Return the list of all elements."""
return list(self._id_to_elem.values())
def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]:
r"""
Returns an iterator over all elements with their names.
"""
def get_elem_iter(self) -> Generator[tuple[str, "Component"], None, None]:
r"""Return an iterator over all elements with their names."""
for elem_id, elem in self._id_to_elem.items():
yield elem_id.split(".")[-1], elem
def get_elem_by_id(self, elem_id: str) -> "Component":
r"""
Gets element by id.
r"""Get element by id.
Example: top.lang, train.dataset
"""
return self._id_to_elem[elem_id]
def get_id_by_elem(self, elem: "Component") -> str:
r"""
Gets id by element.
"""
r"""Get id by element."""
return self._elem_to_id[elem]
def get_base_elems(self) -> Set["Component"]:
r"""
Gets the base elements that are commonly used.
"""
def get_base_elems(self) -> set["Component"]:
r"""Get the base elements that are commonly used."""
return {
self._id_to_elem["top.lang"],
self._id_to_elem["top.model_name"],

View File

@@ -14,9 +14,10 @@
import json
import os
from collections.abc import Generator
from copy import deepcopy
from subprocess import Popen, TimeoutExpired
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
from typing import TYPE_CHECKING, Any, Optional
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_npu_available
@@ -51,17 +52,16 @@ if TYPE_CHECKING:
class Runner:
r"""
A class to manage the running status of the trainers.
"""
r"""A class to manage the running status of the trainers."""
def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
r"""Init a runner."""
self.manager = manager
self.demo_mode = demo_mode
""" Resume """
self.trainer: Optional["Popen"] = None
self.trainer: Optional[Popen] = None
self.do_train = True
self.running_data: Dict["Component", Any] = None
self.running_data: dict[Component, Any] = None
""" State """
self.aborted = False
self.running = False
@@ -71,10 +71,8 @@ class Runner:
if self.trainer is not None:
abort_process(self.trainer.pid)
def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
r"""
Validates the configuration.
"""
def _initialize(self, data: dict["Component", Any], do_train: bool, from_preview: bool) -> str:
r"""Validate the configuration."""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
dataset = get("train.dataset") if do_train else get("eval.dataset")
@@ -116,9 +114,7 @@ class Runner:
return ""
def _finalize(self, lang: str, finish_info: str) -> str:
r"""
Cleans the cached memory and resets the runner.
"""
r"""Clean the cached memory and resets the runner."""
finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
gr.Info(finish_info)
self.trainer = None
@@ -128,10 +124,8 @@ class Runner:
torch_gc()
return finish_info
def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
r"""
Builds and validates the training arguments.
"""
def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
r"""Build and validate the training arguments."""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
@@ -291,10 +285,8 @@ class Runner:
return args
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
r"""
Builds and validates the evaluation arguments.
"""
def _parse_eval_args(self, data: dict["Component", Any]) -> dict[str, Any]:
r"""Build and validate the evaluation arguments."""
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
@@ -345,10 +337,8 @@ class Runner:
return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
r"""
Previews the training commands.
"""
def _preview(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", str], None, None]:
r"""Preview the training commands."""
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=True)
if error:
@@ -358,10 +348,8 @@ class Runner:
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
yield {output_box: gen_cmd(args)}
def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
r"""
Starts the training process.
"""
def _launch(self, data: dict["Component", Any], do_train: bool) -> Generator[dict["Component", Any], None, None]:
r"""Start the training process."""
output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
error = self._initialize(data, do_train, from_preview=False)
if error:
@@ -383,10 +371,8 @@ class Runner:
self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
yield from self.monitor()
def _build_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
r"""
Builds a dictionary containing the current training configuration.
"""
def _build_config_dict(self, data: dict["Component", Any]) -> dict[str, Any]:
r"""Build a dictionary containing the current training configuration."""
config_dict = {}
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
for elem, value in data.items():
@@ -409,9 +395,7 @@ class Runner:
yield from self._launch(data, do_train=False)
def monitor(self):
r"""
Monitors the training progress and logs.
"""
r"""Monitorgit the training progress and logs."""
self.aborted = False
self.running = True
@@ -469,9 +453,7 @@ class Runner:
yield return_dict
def save_args(self, data):
r"""
Saves the training configuration to config path.
"""
r"""Save the training configuration to config path."""
output_box = self.manager.get_elem_by_id("train.output_box")
error = self._initialize(data, do_train=True, from_preview=True)
if error:
@@ -487,27 +469,23 @@ class Runner:
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
def load_args(self, lang: str, config_path: str):
r"""
Loads the training configuration from config path.
"""
r"""Load the training configuration from config path."""
output_box = self.manager.get_elem_by_id("train.output_box")
config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path))
if config_dict is None:
gr.Warning(ALERTS["err_config_not_found"][lang])
return {output_box: ALERTS["err_config_not_found"][lang]}
output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
output_dict: dict[Component, Any] = {output_box: ALERTS["info_config_loaded"][lang]}
for elem_id, value in config_dict.items():
output_dict[self.manager.get_elem_by_id(elem_id)] = value
return output_dict
def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str):
r"""
Restore the training status if output_dir exists.
"""
r"""Restore the training status if output_dir exists."""
output_box = self.manager.get_elem_by_id("train.output_box")
output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]}
output_dict: dict[Component, Any] = {output_box: LOCALES["output_box"][lang]["value"]}
if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
gr.Warning(ALERTS["warn_output_dir_exists"][lang])
output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang]