use pre-commit

Former-commit-id: 7cfede95df22a9ff236788f04159b6b16b8d04bb
This commit is contained in:
hiyouga
2024-10-29 09:07:46 +00:00
parent 8f5921692e
commit 248d5daaff
66 changed files with 1028 additions and 1044 deletions

View File

@@ -75,7 +75,7 @@ def load_config() -> Dict[str, Any]:
Loads user config if exists.
"""
try:
with open(get_config_path(), "r", encoding="utf-8") as f:
with open(get_config_path(), encoding="utf-8") as f:
return safe_load(f)
except Exception:
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
@@ -172,14 +172,14 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
Loads dataset_info.json.
"""
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
logger.info("dataset_dir is {}, using online dataset.".format(dataset_dir))
logger.info(f"dataset_dir is {dataset_dir}, using online dataset.")
return {}
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
return json.load(f)
except Exception as err:
logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
logger.warning(f"Cannot open {os.path.join(dataset_dir, DATA_CONFIG)} due to {str(err)}.")
return {}

View File

@@ -41,7 +41,7 @@ def next_page(page_index: int, total_num: int) -> int:
def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
try:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
except Exception:
return gr.Button(interactive=False)
@@ -57,7 +57,7 @@ def can_preview(dataset_dir: str, dataset: list) -> "gr.Button":
def _load_data_file(file_path: str) -> List[Any]:
with open(file_path, "r", encoding="utf-8") as f:
with open(file_path, encoding="utf-8") as f:
if file_path.endswith(".json"):
return json.load(f)
elif file_path.endswith(".jsonl"):
@@ -67,7 +67,7 @@ def _load_data_file(file_path: str) -> List[Any]:
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, "gr.Column"]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
dataset_info = json.load(f)
data_path = os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])

View File

@@ -56,9 +56,9 @@ class Engine:
if not self.pure_chat:
current_time = get_time()
init_dict["train.current_time"] = {"value": current_time}
init_dict["train.output_dir"] = {"value": "train_{}".format(current_time)}
init_dict["train.config_path"] = {"value": "{}.yaml".format(current_time)}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(current_time)}
init_dict["train.output_dir"] = {"value": f"train_{current_time}"}
init_dict["train.config_path"] = {"value": f"{current_time}.yaml"}
init_dict["eval.output_dir"] = {"value": f"eval_{current_time}"}
init_dict["infer.mm_box"] = {"visible": False}
if user_config.get("last_model", None):

View File

@@ -29,7 +29,7 @@ class Manager:
Adds elements to manager.
"""
for elem_name, elem in elem_dict.items():
elem_id = "{}.{}".format(tab_name, elem_name)
elem_id = f"{tab_name}.{elem_name}"
self._id_to_elem[elem_id] = elem
self._elem_to_id[elem] = elem_id

View File

@@ -231,7 +231,7 @@ class Runner:
if get("train.ds_stage") != "none":
ds_stage = get("train.ds_stage")
ds_offload = "offload_" if get("train.ds_offload") else ""
args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, "ds_z{}_{}config.json".format(ds_stage, ds_offload))
args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, f"ds_z{ds_stage}_{ds_offload}config.json")
return args
@@ -313,7 +313,7 @@ class Runner:
if args.get("deepspeed", None) is not None:
env["FORCE_TORCHRUN"] = "1"
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
self.trainer = Popen(f"llamafactory-cli train {save_cmd(args)}", env=env, shell=True)
yield from self.monitor()
def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:

View File

@@ -111,14 +111,14 @@ def gen_cmd(args: Dict[str, Any]) -> str:
"""
cmd_lines = ["llamafactory-cli train "]
for k, v in clean_cmd(args).items():
cmd_lines.append(" --{} {} ".format(k, str(v)))
cmd_lines.append(f" --{k} {str(v)} ")
if os.name == "nt":
cmd_text = "`\n".join(cmd_lines)
else:
cmd_text = "\\\n".join(cmd_lines)
cmd_text = "```bash\n{}\n```".format(cmd_text)
cmd_text = f"```bash\n{cmd_text}\n```"
return cmd_text
@@ -139,9 +139,9 @@ def get_eval_results(path: os.PathLike) -> str:
r"""
Gets scores after evaluation.
"""
with open(path, "r", encoding="utf-8") as f:
with open(path, encoding="utf-8") as f:
result = json.dumps(json.load(f), indent=4)
return "```json\n{}\n```\n".format(result)
return f"```json\n{result}\n```\n"
def get_time() -> str:
@@ -161,13 +161,13 @@ def get_trainer_info(output_path: os.PathLike, do_train: bool) -> Tuple[str, "gr
running_log_path = os.path.join(output_path, RUNNING_LOG)
if os.path.isfile(running_log_path):
with open(running_log_path, "r", encoding="utf-8") as f:
with open(running_log_path, encoding="utf-8") as f:
running_log = f.read()
trainer_log_path = os.path.join(output_path, TRAINER_LOG)
if os.path.isfile(trainer_log_path):
trainer_log: List[Dict[str, Any]] = []
with open(trainer_log_path, "r", encoding="utf-8") as f:
with open(trainer_log_path, encoding="utf-8") as f:
for line in f:
trainer_log.append(json.loads(line))
@@ -193,7 +193,7 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
Loads saved arguments.
"""
try:
with open(config_path, "r", encoding="utf-8") as f:
with open(config_path, encoding="utf-8") as f:
return safe_load(f)
except Exception:
return None
@@ -211,7 +211,7 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":
r"""
Lists all the saved configuration files.
"""
config_files = ["{}.yaml".format(current_time)]
config_files = [f"{current_time}.yaml"]
if os.path.isdir(DEFAULT_CONFIG_DIR):
for file_name in os.listdir(DEFAULT_CONFIG_DIR):
if file_name.endswith(".yaml") and file_name not in config_files:
@@ -224,7 +224,7 @@ def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_ti
r"""
Lists all the directories that can resume from.
"""
output_dirs = ["train_{}".format(current_time)]
output_dirs = [f"train_{current_time}"]
if model_name:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):