refactor dataset_attr, add eos in pt, fix #757

Former-commit-id: 0feec9a830b917b36686b61938a66e842eccf930
This commit is contained in:
hiyouga
2023-09-01 19:00:45 +08:00
parent 93be211f80
commit e5b72c6a77
19 changed files with 108 additions and 126 deletions

View File

@@ -36,7 +36,7 @@ check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")
require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
def load_model_and_tokenizer(

View File

@@ -5,6 +5,7 @@ import datasets
import transformers
from typing import Any, Dict, Optional, Tuple
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.utils.versions import require_version
from transformers.trainer_utils import get_last_checkpoint
from llmtuner.extras.logging import get_logger
@@ -110,6 +111,11 @@ def get_train_args(
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
raise ValueError("PPO and DPO stages can only be performed at training.")
if general_args.stage in ["rm", "dpo"]:
for dataset_attr in data_args.dataset_list:
if not dataset_attr.ranking:
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
if general_args.stage == "ppo" and model_args.reward_model is None:
raise ValueError("Reward model is necessary for PPO training.")
@@ -166,6 +172,7 @@ def get_train_args(
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
):
require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
@@ -186,18 +193,6 @@ def get_train_args(
else:
model_args.compute_dtype = torch.float16
# transfer training stage to dataset stage
dataset_stage = general_args.stage
if general_args.stage == "ppo":
dataset_stage = "sft"
elif general_args.stage == "dpo":
dataset_stage = "rm"
for dataset_attr in data_args.dataset_list:
if dataset_attr.stage and dataset_attr.stage != dataset_stage:
raise ValueError("Dataset {} is not supported for the stage {}"
.format(dataset_attr.dataset_name, general_args.stage))
model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
# Log on each process the small summary:

View File

@@ -1,9 +1,9 @@
import torch
from collections import defaultdict
from peft import PeftModel
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer
from trl import DPOTrainer
from trl.trainer.utils import disable_dropout_in_model
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core.trainer import PeftModelMixin
@@ -18,9 +18,16 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
def __init__(
self,
finetuning_args: "FinetuningArguments",
model: Union["PreTrainedModel", torch.nn.Module],
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
disable_dropout: Optional[bool] = True,
**kwargs
):
if disable_dropout:
disable_dropout_in_model(model)
if ref_model is not None:
disable_dropout_in_model(ref_model)
self.finetuning_args = finetuning_args
self.ref_model = ref_model
self.use_dpo_data_collator = True # hack to avoid warning
@@ -29,12 +36,16 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
self.beta = finetuning_args.dpo_beta
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, **kwargs)
Trainer.__init__(self, model=model, **kwargs)
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
if ref_model is not None:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
if self.is_deepspeed_enabled:
self.ref_model = self.accelerator._prepare_deepspeed(self.ref_model)
self.ref_model.eval()
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def concatenated_forward(
self,
@@ -42,27 +53,12 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
batch: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
if not torch.is_grad_enabled():
unwrapped_model.gradient_checkpointing_disable()
if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
with unwrapped_model.disable_adapter():
all_logits = self.model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True
).logits.to(torch.float32)
else:
all_logits = model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True
).logits.to(torch.float32)
if not torch.is_grad_enabled():
unwrapped_model.gradient_checkpointing_enable()
all_logits = model(
input_ids=batch_copied["input_ids"],
attention_mask=batch_copied["attention_mask"],
return_dict=True
).logits.to(torch.float32)
all_logps = self._get_batch_logps(
all_logits,

View File

@@ -202,7 +202,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
return_logits: Optional[bool] = False
return_logits: Optional[bool] = False,
response_masks: Optional[torch.Tensor] = None
):
r"""
Calculates model outputs in multiple batches.
@@ -220,6 +221,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
query_batch = queries[i * fbs : (i + 1) * fbs]
response_batch = responses[i * fbs : (i + 1) * fbs]
if response_masks is not None:
response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]
@@ -239,8 +242,15 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
start += attention_mask[j, :].nonzero()[0]
end = start + len(response_batch[j])
if response_masks is not None:
response_masks_batch = torch.cat(
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
)[1:]
masks[j, :start] = 0
masks[j, end:] = 0
if response_masks is not None:
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
if return_logits:
all_logits.append(logits)

View File

@@ -44,7 +44,6 @@ def run_ppo(
)
if finetuning_args.ppo_score_norm:
require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
ppo_config.use_score_scaling = True
ppo_config.use_score_norm = True