better llamaboard

* easily resume from checkpoint
* support full and freeze checkpoints
* faster UI

Former-commit-id: 84cfb2452cc86b037ccddee6e833f8eb7c129fa4
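In outline, the resumed-checkpoint handling this commit adds to both _parse_train_args and _parse_eval_args picks between adapter checkpoints (LoRA-style methods) and full/freeze checkpoints. A minimal, self-contained sketch of that logic, assuming PEFT_METHODS contains the adapter-based methods (e.g. "lora") and using a simplified stand-in for get_save_dir (the real ones live in ..extras.constants and .common):

from typing import Any, Dict, List, Optional, Union

PEFT_METHODS = {"lora"}  # assumption: the real set comes from ..extras.constants

def get_save_dir(*paths: str) -> str:
    # simplified stand-in for .common.get_save_dir
    return "/".join(("saves",) + paths)

def resolve_checkpoint(
    args: Dict[str, Any],
    model_name: str,
    finetuning_type: str,
    checkpoint_path: Optional[Union[str, List[str]]],
) -> Dict[str, Any]:
    # Adapter checkpoints extend the base model; full/freeze checkpoints replace it.
    if checkpoint_path:
        if finetuning_type in PEFT_METHODS:  # list of adapter folders
            args["adapter_name_or_path"] = ",".join(
                get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path
            )
        else:  # single folder produced by full or freeze training
            args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
    return args

For example, resolve_checkpoint({}, "my-model", "lora", ["ckpt-1"]) sets adapter_name_or_path to "saves/my-model/lora/ckpt-1", while a full-finetune checkpoint folder replaces model_name_or_path outright.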
@@ -7,12 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
 import psutil
 from transformers.trainer import TRAINING_ARGS_NAME
 
-from ..extras.constants import TRAINING_STAGES
+from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
 from ..extras.misc import is_gpu_or_npu_available, torch_gc
 from ..extras.packages import is_gradio_available
-from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_args, load_config, save_args
+from .common import DEFAULT_CACHE_DIR, get_module, get_save_dir, load_config
 from .locales import ALERTS
-from .utils import gen_cmd, get_eval_results, get_trainer_info, save_cmd
+from .utils import gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
 
 
 if is_gradio_available():
@@ -85,26 +85,16 @@ class Runner:
 
     def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
+        model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
 
-        if get("top.adapter_path"):
-            adapter_name_or_path = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("top.adapter_path")
-                ]
-            )
-        else:
-            adapter_name_or_path = None
-
         args = dict(
             stage=TRAINING_STAGES[get("train.training_stage")],
             do_train=True,
             model_name_or_path=get("top.model_path"),
-            adapter_name_or_path=adapter_name_or_path,
             cache_dir=user_config.get("cache_dir", None),
             preprocessing_num_workers=16,
-            finetuning_type=get("top.finetuning_type"),
+            finetuning_type=finetuning_type,
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
@@ -134,13 +124,23 @@ class Runner:
             report_to="all" if get("train.report_to") else "none",
             use_galore=get("train.use_galore"),
             use_badam=get("train.use_badam"),
-            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
+            output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
             bf16=(get("train.compute_type") == "bf16"),
             pure_bf16=(get("train.compute_type") == "pure_bf16"),
             plot_loss=True,
             ddp_timeout=180000000,
         )
 
+        # checkpoints
+        if get("top.checkpoint_path"):
+            if finetuning_type in PEFT_METHODS:  # list
+                args["adapter_name_or_path"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
+                )
+            else:  # str
+                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
+
+        # freeze config
         if args["finetuning_type"] == "freeze":
             args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
@@ -156,7 +156,7 @@ class Runner:
             args["create_new_adapter"] = get("train.create_new_adapter")
             args["use_rslora"] = get("train.use_rslora")
             args["use_dora"] = get("train.use_dora")
-            args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
+            args["lora_target"] = get("train.lora_target") or get_module(model_name)
             args["additional_target"] = get("train.additional_target") or None
 
             if args["use_llama_pro"]:
@@ -164,13 +164,14 @@ class Runner:
 
         # rlhf config
         if args["stage"] == "ppo":
-            args["reward_model"] = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("train.reward_model")
-                ]
-            )
-            args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"
+            if finetuning_type in PEFT_METHODS:
+                args["reward_model"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("train.reward_model")]
+                )
+            else:
+                args["reward_model"] = get_save_dir(model_name, finetuning_type, get("train.reward_model"))
+
+            args["reward_model_type"] = "lora" if finetuning_type == "lora" else "full"
             args["ppo_score_norm"] = get("train.ppo_score_norm")
             args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
             args["top_k"] = 0
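A note on the rlhf hunk above: with adapter-based methods the PPO reward model becomes a comma-joined list of adapter directories, while full or freeze training points at a single saved model directory. A rough standalone sketch, under the same stand-in assumptions as the checkpoint sketch near the top (PEFT_METHODS and get_save_dir are simplified placeholders):

from typing import Dict, List, Union

PEFT_METHODS = {"lora"}  # assumption: the real set comes from ..extras.constants

def get_save_dir(*paths: str) -> str:
    # simplified stand-in for .common.get_save_dir
    return "/".join(("saves",) + paths)

def resolve_reward_model(
    model_name: str, finetuning_type: str, reward_model: Union[str, List[str]]
) -> Dict[str, str]:
    # Mirrors the PPO branch of _parse_train_args after this change.
    if finetuning_type in PEFT_METHODS:
        path = ",".join(get_save_dir(model_name, finetuning_type, adapter) for adapter in reward_model)
    else:
        path = get_save_dir(model_name, finetuning_type, reward_model)
    return {
        "reward_model": path,
        "reward_model_type": "lora" if finetuning_type == "lora" else "full",
    }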
@@ -211,25 +212,15 @@ class Runner:
 
     def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
        get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
+        model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
 
-        if get("top.adapter_path"):
-            adapter_name_or_path = ",".join(
-                [
-                    get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
-                    for adapter in get("top.adapter_path")
-                ]
-            )
-        else:
-            adapter_name_or_path = None
-
         args = dict(
             stage="sft",
             model_name_or_path=get("top.model_path"),
-            adapter_name_or_path=adapter_name_or_path,
             cache_dir=user_config.get("cache_dir", None),
             preprocessing_num_workers=16,
-            finetuning_type=get("top.finetuning_type"),
+            finetuning_type=finetuning_type,
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
@@ -245,7 +236,7 @@ class Runner:
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
-            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir")),
+            output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
         )
 
         if get("eval.predict"):
@@ -253,6 +244,14 @@ class Runner:
         else:
             args["do_eval"] = True
 
+        if get("top.checkpoint_path"):
+            if finetuning_type in PEFT_METHODS:  # list
+                args["adapter_name_or_path"] = ",".join(
+                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
+                )
+            else:  # str
+                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
+
         return args
 
     def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
@@ -296,9 +295,7 @@ class Runner:
         self.running = True
 
         get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
-        lang = get("top.lang")
-        model_name = get("top.model_name")
-        finetuning_type = get("top.finetuning_type")
+        lang, model_name, finetuning_type = get("top.lang"), get("top.model_name"), get("top.finetuning_type")
         output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
         output_path = get_save_dir(model_name, finetuning_type, output_dir)
 
@@ -356,7 +353,7 @@ class Runner:
         config_dict: Dict[str, Any] = {}
         lang = data[self.manager.get_elem_by_id("top.lang")]
         config_path = data[self.manager.get_elem_by_id("train.config_path")]
-        skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path"]
+        skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"]
        for elem, value in data.items():
             elem_id = self.manager.get_id_by_elem(elem)
             if elem_id not in skip_ids:
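On the last hunk: saving a llamaboard config simply filters out a handful of UI-only fields, now including train.device_count. A minimal sketch of that filtering, using a plain elem_id -> value dict as a stand-in (an assumption; the real method walks a Gradio component mapping and resolves ids through self.manager):

from typing import Any, Dict

SKIP_IDS = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"]

def build_config_dict(data: Dict[str, Any]) -> Dict[str, Any]:
    # Keep every element value except the UI-only fields listed above.
    return {elem_id: value for elem_id, value in data.items() if elem_id not in SKIP_IDS}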