better llamaboard

* easily resume from checkpoint
* support full and freeze checkpoints
* faster UI


Former-commit-id: 84cfb2452cc86b037ccddee6e833f8eb7c129fa4
This commit is contained in:
hiyouga
2024-05-29 23:55:38 +08:00
parent f90c4ca672
commit 87aa332583
14 changed files with 303 additions and 193 deletions

View File

@@ -7,12 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
import psutil
from transformers.trainer import TRAINING_ARGS_NAME
from ..extras.constants import TRAINING_STAGES
from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
from ..extras.misc import is_gpu_or_npu_available, torch_gc
from ..extras.packages import is_gradio_available
from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_args, load_config, save_args
from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config
from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
from .utils import gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
if is_gradio_available():
@@ -85,26 +85,16 @@ class Runner:
def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
if get("top.adapter_path"):
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None
args = dict(
stage=TRAINING_STAGES[get("train.training_stage")],
do_train=True,
model_name_or_path=get("top.model_path"),
adapter_name_or_path=adapter_name_or_path,
cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16,
finetuning_type=get("top.finetuning_type"),
finetuning_type=finetuning_type,
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
@@ -134,13 +124,23 @@ class Runner:
report_to="all" if get("train.report_to") else "none",
use_galore=get("train.use_galore"),
use_badam=get("train.use_badam"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16"),
pure_bf16=(get("train.compute_type") == "pure_bf16"),
plot_loss=True,
ddp_timeout=180000000,
)
# checkpoints
if get("top.checkpoint_path"):
if finetuning_type in PEFT_METHODS: # list
args["adapter_name_or_path"] = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
)
else: # str
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
# freeze config
if args["finetuning_type"] == "freeze":
args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
@@ -156,7 +156,7 @@ class Runner:
args["create_new_adapter"] = get("train.create_new_adapter")
args["use_rslora"] = get("train.use_rslora")
args["use_dora"] = get("train.use_dora")
args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
args["lora_target"] = get("train.lora_target") or get_module(model_name)
args["additional_target"] = get("train.additional_target") or None
if args["use_llama_pro"]:
@@ -164,13 +164,14 @@ class Runner:
# rlhf config
if args["stage"] == "ppo":
args["reward_model"] = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("train.reward_model")
]
)
args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
if finetuning_type in PEFT_METHODS:
args["reward_model"] = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in get("train.reward_model")]
)
else:
args["reward_model"] = get_save_dir(model_name, finetuning_type, get("train.reward_model"))
args["reward_model_type"] = "lora" if finetuning_type == "lora" else "full"
args["ppo_score_norm"] = get("train.ppo_score_norm")
args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
args["top_k"] = 0
@@ -211,25 +212,15 @@ class Runner:
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
user_config = load_config()
if get("top.adapter_path"):
adapter_name_or_path = ",".join(
[
get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
for adapter in get("top.adapter_path")
]
)
else:
adapter_name_or_path = None
args = dict(
stage="sft",
model_name_or_path=get("top.model_path"),
adapter_name_or_path=adapter_name_or_path,
cache_dir=user_config.get("cache_dir", None),
preprocessing_num_workers=16,
finetuning_type=get("top.finetuning_type"),
finetuning_type=finetuning_type,
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
template=get("top.template"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
@@ -245,7 +236,7 @@ class Runner:
max_new_tokens=get("eval.max_new_tokens"),
top_p=get("eval.top_p"),
temperature=get("eval.temperature"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
)
if get("eval.predict"):
@@ -253,6 +244,14 @@ class Runner:
else:
args["do_eval"] = True
if get("top.checkpoint_path"):
if finetuning_type in PEFT_METHODS: # list
args["adapter_name_or_path"] = ",".join(
[get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
)
else: # str
args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
return args
def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
@@ -296,9 +295,7 @@ class Runner:
self.running = True
get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
lang = get("top.lang")
model_name = get("top.model_name")
finetuning_type = get("top.finetuning_type")
lang, model_name, finetuning_type = get("top.lang"), get("top.model_name"), get("top.finetuning_type")
output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
output_path = get_save_dir(model_name, finetuning_type, output_dir)
@@ -356,7 +353,7 @@ class Runner:
config_dict: Dict[str, Any] = {}
lang = data[self.manager.get_elem_by_id("top.lang")]
config_path = data[self.manager.get_elem_by_id("train.config_path")]
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"]
for elem, value in data.items():
elem_id = self.manager.get_id_by_elem(elem)
if elem_id not in skip_ids: