update web UI, support rm predict #210

Former-commit-id: 92cc6b655dc91b94d5bf9d8618c3b57d5cf94333
This commit is contained in:
hiyouga
2023-07-21 13:27:27 +08:00
parent c4e9694c6e
commit 0f7cdac207
13 changed files with 192 additions and 27 deletions

View File

@@ -54,7 +54,7 @@ def get_train_args(
assert not (training_args.do_train and training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True while training."
assert (not training_args.do_predict) or training_args.predict_with_generate, \
assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
"Please enable `predict_with_generate` to save model predictions."
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \

View File

@@ -4,7 +4,8 @@ from typing import Dict, Optional
from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from peft import PeftModel
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger
@@ -49,9 +50,9 @@ class PeftTrainer(Seq2SeqTrainer):
else:
backbone_model = model
if self.finetuning_args.finetuning_type == "lora":
if isinstance(backbone_model, PeftModel): # LoRA tuning
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
else: # freeze/full tuning
elif isinstance(backbone_model, PreTrainedModel): # freeze/full tuning
backbone_model.config.use_cache = True
backbone_model.save_pretrained(
output_dir,
@@ -61,6 +62,8 @@ class PeftTrainer(Seq2SeqTrainer):
backbone_model.config.use_cache = False
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
else:
logger.warning("No model to save.")
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
@@ -77,8 +80,8 @@ class PeftTrainer(Seq2SeqTrainer):
model = unwrap_model(self.model)
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
if self.finetuning_args.finetuning_type == "lora":
backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
if isinstance(backbone_model, PeftModel):
backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
model.v_head.load_state_dict({
"summary.weight": getattr(model, "reward_head_weight"),

View File

@@ -1,10 +1,17 @@
import os
import json
import torch
from typing import Dict, List, Optional, Tuple, Union
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.trainer import PeftTrainer
logger = get_logger(__name__)
class PairwisePeftTrainer(PeftTrainer):
r"""
Inherits PeftTrainer to compute pairwise loss.
@@ -36,3 +43,26 @@ class PairwisePeftTrainer(PeftTrainer):
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
def save_predictions(
self,
predict_results: PredictionOutput
) -> None:
r"""
Saves model predictions to `output_dir`.
A custom behavior that not contained in Seq2SeqTrainer.
"""
if not self.is_world_process_zero():
return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
acc_scores, rej_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for acc_score, rej_score in zip(acc_scores, rej_scores):
res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)}))
writer.write("\n".join(res))

View File

@@ -56,3 +56,10 @@ def run_rm(
metrics = trainer.evaluate(metric_key_prefix="eval")
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict")
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results)