modify code structure
Former-commit-id: 0682ed357210897e0b67c4a6eb31a94b3eb929f1
src/llmtuner/tuner/rm/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from llmtuner.tuner.rm.workflow import run_rm
src/llmtuner/tuner/rm/collator.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import torch
from typing import Any, Dict, Sequence
from transformers import DataCollatorWithPadding


class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
    r"""
    Data collator for pairwise data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples, where the first n examples are the chosen ones and
        the last n examples are the rejected ones.
        """
        features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
        return super().__call__(features)
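A minimal usage sketch of the collator (not part of the commit; the gpt2 tokenizer and the token IDs are made-up stand-ins): two pairs become four rows, chosen sequences first, then rejected, all padded to the longest length.

from transformers import AutoTokenizer

# Hypothetical setup: any tokenizer with a pad token works.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

collator = PairwiseDataCollatorWithPadding(tokenizer)
features = [
    {"accept_ids": [1, 2, 3, 4], "reject_ids": [5, 6]},  # pair 1
    {"accept_ids": [7, 8], "reject_ids": [9, 10, 11]},   # pair 2
]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([4, 4]): rows 0-1 chosen, rows 2-3 rejected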
src/llmtuner/tuner/rm/metric.py (new file, 7 lines)
@@ -0,0 +1,7 @@
import numpy as np
from typing import Dict, Sequence, Tuple, Union


def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
    preds, _ = eval_preds
    return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
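For illustration (made-up numbers, not from the commit): the trainer passes the (r_accept, r_reject) pair as predictions, and the metric is the fraction of pairs whose chosen reward exceeds the rejected one.

import numpy as np

chosen = np.array([1.2, 0.3, 2.0])    # rewards of the chosen responses
rejected = np.array([0.5, 0.9, 1.1])  # rewards of the rejected responses
print(compute_accuracy(((chosen, rejected), None)))  # {'accuracy': 0.6666...}: 2 of 3 pairs ordered correctly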
src/llmtuner/tuner/rm/trainer.py (new file, 38 lines)
@@ -0,0 +1,38 @@
import torch
from typing import Dict, List, Optional, Tuple, Union
from transformers.modeling_utils import PreTrainedModel

from llmtuner.tuner.core.trainer import PeftTrainer


class PairwisePeftTrainer(PeftTrainer):
    r"""
    Inherits PeftTrainer to compute pairwise loss.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.can_return_loss = True # override property to return eval_loss

    def compute_loss(
        self,
        model: PreTrainedModel,
        inputs: Dict[str, torch.Tensor],
        return_outputs: Optional[bool] = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        r"""
        Computes the pairwise loss. The first n examples are chosen and the last n examples are rejected.

        We use the score on the EOS token to represent the reward of the whole sentence.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.

        Note that the first element will be removed from the output tuple.

        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
        """
        batch_size = inputs["input_ids"].size(0) // 2
        _, _, values = model(**inputs)
        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
        return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
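To see the arithmetic of compute_loss in isolation, a worked sketch with made-up last-token values (in the trainer they come from values[:, -1] of a model whose third output is a per-token value tensor):

import torch

last_values = torch.tensor([1.5, 0.2, 0.4, -0.1])  # 2 chosen then 2 rejected
r_accept, r_reject = last_values.split(2, dim=0)   # (1.5, 0.2) vs (0.4, -0.1)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
print(loss)  # tensor(0.4208), equivalent to softplus(r_reject - r_accept).mean()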
src/llmtuner/tuner/rm/workflow.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# Inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py

from transformers import Seq2SeqTrainingArguments

from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.rm.metric import compute_accuracy
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer


def run_rm(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: Seq2SeqTrainingArguments,
    finetuning_args: FinetuningArguments
):
    dataset = get_dataset(model_args, data_args)
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
    data_collator = PairwiseDataCollatorWithPadding(tokenizer)

    training_args.remove_unused_columns = False # important for pairwise dataset

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = PairwisePeftTrainer(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        compute_metrics=compute_accuracy,
        **trainer_kwargs
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
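A hypothetical driver script (not part of the commit) showing the call shape of run_rm. The real project wires these arguments through its own parser; this sketch assumes the four argument classes are plain dataclasses that HfArgumentParser can parse directly.

from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.rm import run_rm

# Parses --model_name_or_path, --output_dir, etc. from the command line.
parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments))
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
run_rm(model_args, data_args, training_args, finetuning_args)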