refactor adapter hparam

Former-commit-id: f82aece9ebd6df83a7a005cc7cbbcec07fa6e14d
Author: hiyouga
Date: 2023-12-15 20:53:11 +08:00
Parent: 27ef5b1aa7
Commit: f902b0d420

21 changed files with 302 additions and 311 deletions

@@ -86,19 +86,19 @@ class Runner:
         get = lambda name: data[self.manager.get_elem_by_name(name)]
         user_config = load_config()
-        if get("top.checkpoints"):
-            checkpoint_dir = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
-            ])
+        if get("top.adapter_path"):
+            adapter_name_or_path = ",".join([
+                get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+                for adapter in get("top.adapter_path")])
         else:
-            checkpoint_dir = None
+            adapter_name_or_path = None
         args = dict(
             stage=TRAINING_STAGES[get("train.training_stage")],
-            model_name_or_path=get("top.model_path"),
             do_train=True,
+            model_name_or_path=get("top.model_path"),
+            adapter_name_or_path=adapter_name_or_path,
             cache_dir=user_config.get("cache_dir", None),
-            checkpoint_dir=checkpoint_dir,
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
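For context, a minimal, self-contained sketch of what the new comma-joined adapter_name_or_path value looks like. The save_dir helper and the saves/ root below are stand-ins for the repository's get_save_dir, and the adapter names are hypothetical:

import os

def save_dir(model_name, finetuning_type, name):
    # Stand-in for get_save_dir(); assumes outputs live under a local "saves/" root.
    return os.path.join("saves", model_name, finetuning_type, name)

# Hypothetical values as they would come from the "top.*" widgets.
model_name = "LLaMA2-7B"
finetuning_type = "lora"
adapter_path = ["sft-round-1", "sft-round-2"]

# Mirrors the refactored branch: one comma-separated string naming every selected adapter,
# which downstream code can split back into individual adapter directories.
adapter_name_or_path = ",".join(
    save_dir(model_name, finetuning_type, adapter) for adapter in adapter_path
)
print(adapter_name_or_path)
# e.g. saves/LLaMA2-7B/lora/sft-round-1,saves/LLaMA2-7B/lora/sft-round-2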
@@ -125,17 +125,14 @@ class Runner:
             lora_dropout=get("train.lora_dropout"),
             lora_target=get("train.lora_target") or get_module(get("top.model_name")),
             additional_target=get("train.additional_target") if get("train.additional_target") else None,
-            resume_lora_training=get("train.resume_lora_training"),
+            create_new_adapter=get("train.create_new_adapter"),
             output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
         )
         args[get("train.compute_type")] = True
         args["disable_tqdm"] = True
         if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
-            args["resume_lora_training"] = (args["quantization_bit"] is not None)
-        if args["quantization_bit"] is not None:
-            args["upcast_layernorm"] = True
+            args["create_new_adapter"] = (args["quantization_bit"] is None)
         if args["stage"] == "ppo":
             args["reward_model"] = get_save_dir(
@@ -158,20 +155,19 @@ class Runner:
         get = lambda name: data[self.manager.get_elem_by_name(name)]
         user_config = load_config()
-        if get("top.checkpoints"):
-            checkpoint_dir = ",".join([
-                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
-            ])
+        if get("top.adapter_path"):
+            adapter_name_or_path = ",".join([
+                get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
+                for adapter in get("top.adapter_path")])
         else:
-            checkpoint_dir = None
+            adapter_name_or_path = None
         args = dict(
             stage="sft",
-            model_name_or_path=get("top.model_path"),
             do_eval=True,
-            predict_with_generate=True,
+            model_name_or_path=get("top.model_path"),
+            adapter_name_or_path=adapter_name_or_path,
             cache_dir=user_config.get("cache_dir", None),
-            checkpoint_dir=checkpoint_dir,
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
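The evaluation parser now repeats the same adapter-joining branch as the training parser. Purely as an illustration (this consolidation is not part of the commit), the shared logic could live in one helper:

from typing import Callable, List, Optional

def join_adapters(
    model_name: str,
    finetuning_type: str,
    adapters: Optional[List[str]],
    get_save_dir: Callable[[str, str, str], str],
) -> Optional[str]:
    # Returns the comma-separated adapter_name_or_path, or None when no adapter is selected.
    if not adapters:
        return None
    return ",".join(get_save_dir(model_name, finetuning_type, adapter) for adapter in adapters)

# Usage inside either parser would then read:
# adapter_name_or_path = join_adapters(
#     get("top.model_name"), get("top.finetuning_type"), get("top.adapter_path"), get_save_dir
# )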
@@ -183,6 +179,7 @@ class Runner:
             cutoff_len=get("eval.cutoff_len"),
             max_samples=int(get("eval.max_samples")),
             per_device_eval_batch_size=get("eval.batch_size"),
+            predict_with_generate=True,
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
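Moving predict_with_generate=True down into the generation-related keywords is purely cosmetic, since keyword order in dict() does not change the resulting arguments. Both parsers also coerce the quantization dropdown with the same conditional expression; a sketch of that coercion in isolation (the function name is ours, not the repository's):

from typing import Optional

def parse_quantization_bit(value: str) -> Optional[int]:
    # Mirrors the inline expression used by both parsers:
    # only the strings "8" and "4" enable quantization; anything else means no quantization.
    return int(value) if value in ["8", "4"] else None

assert parse_quantization_bit("8") == 8
assert parse_quantization_bit("4") == 4
assert parse_quantization_bit("None") is None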