update web UI, support rm predict #210

Former-commit-id: 92cc6b655dc91b94d5bf9d8618c3b57d5cf94333
This commit is contained in:
hiyouga
2023-07-21 13:27:27 +08:00
parent c4e9694c6e
commit 0f7cdac207
13 changed files with 192 additions and 27 deletions

View File

@@ -54,7 +54,7 @@ def get_train_args(
assert not (training_args.do_train and training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True while training."
assert (not training_args.do_predict) or training_args.predict_with_generate, \
assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
"Please enable `predict_with_generate` to save model predictions."
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \

View File

@@ -4,7 +4,8 @@ from typing import Dict, Optional
from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from peft import PeftModel
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger
@@ -49,9 +50,9 @@ class PeftTrainer(Seq2SeqTrainer):
else:
backbone_model = model
if self.finetuning_args.finetuning_type == "lora":
if isinstance(backbone_model, PeftModel): # LoRA tuning
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
else: # freeze/full tuning
elif isinstance(backbone_model, PreTrainedModel): # freeze/full tuning
backbone_model.config.use_cache = True
backbone_model.save_pretrained(
output_dir,
@@ -61,6 +62,8 @@ class PeftTrainer(Seq2SeqTrainer):
backbone_model.config.use_cache = False
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
else:
logger.warning("No model to save.")
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
@@ -77,8 +80,8 @@ class PeftTrainer(Seq2SeqTrainer):
model = unwrap_model(self.model)
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
if self.finetuning_args.finetuning_type == "lora":
backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
if isinstance(backbone_model, PeftModel):
backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),