refactor dataset_attr, add eos in pt, fix #757
Former-commit-id: 0feec9a830b917b36686b61938a66e842eccf930
@@ -36,7 +36,7 @@ check_min_version("4.29.1")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
-require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")
+require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")


 def load_model_and_tokenizer(
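Note on the bumped pin: `require_version` checks the installed distribution against the specifier at import time and raises an `ImportError` with the hint appended when it is not satisfied, so a stale trl install now fails fast instead of breaking later inside the trainers. A minimal sketch of that behavior (the try/except is illustrative only):

    from transformers.utils.versions import require_version

    try:
        # No-op if the installed trl satisfies the pin; otherwise raises.
        require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
    except ImportError as err:
        print(err)  # carries both the unmet requirement and the fix hint
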
@@ -5,6 +5,7 @@ import datasets
 import transformers
 from typing import Any, Dict, Optional, Tuple
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.utils.versions import require_version
+from transformers.trainer_utils import get_last_checkpoint

 from llmtuner.extras.logging import get_logger
@@ -110,6 +111,11 @@ def get_train_args(
     if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
         raise ValueError("PPO and DPO stages can only be performed at training.")

+    if general_args.stage in ["rm", "dpo"]:
+        for dataset_attr in data_args.dataset_list:
+            if not dataset_attr.ranking:
+                raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
+
     if general_args.stage == "ppo" and model_args.reward_model is None:
         raise ValueError("Reward model is necessary for PPO training.")
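The new guard assumes each entry of `data_args.dataset_list` exposes a boolean `ranking` attribute marking pairwise (chosen/rejected) data. A self-contained sketch of the shape being validated; the dataclass is a hypothetical stand-in, only the field names and the error message come from the diff:

    from dataclasses import dataclass

    @dataclass
    class DatasetAttr:  # hypothetical stand-in for llmtuner's dataset_attr
        dataset_name: str
        ranking: bool = False  # True for pairwise preference data

    def validate_stage(stage: str, dataset_list: list) -> None:
        # Reward modeling and DPO both optimize over preference pairs,
        # so plain instruction data is rejected up front.
        if stage in ["rm", "dpo"]:
            for dataset_attr in dataset_list:
                if not dataset_attr.ranking:
                    raise ValueError("Please use ranked datasets for reward modeling or DPO training.")

    validate_stage("dpo", [DatasetAttr("comparison_data", ranking=True)])  # passes
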
@@ -166,6 +172,7 @@ def get_train_args(
         and os.path.isdir(training_args.output_dir)
         and not training_args.overwrite_output_dir
     ):
+        require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
             raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
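`get_last_checkpoint` (imported above) scans a directory for the newest `checkpoint-*` subdirectory and returns None when there is none, which is what triggers the non-empty-directory error. A small sketch of that lookup, with a hypothetical path:

    import os
    from transformers.trainer_utils import get_last_checkpoint

    output_dir = "saves/test_run"  # hypothetical output directory
    if os.path.isdir(output_dir):
        last_checkpoint = get_last_checkpoint(output_dir)  # newest checkpoint-* subdir, or None
        if last_checkpoint is None and len(os.listdir(output_dir)) > 0:
            raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
        if last_checkpoint is not None:
            print("Resuming from:", last_checkpoint)
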
@@ -186,18 +193,6 @@ def get_train_args(
     else:
         model_args.compute_dtype = torch.float16

-    # transfer training stage to dataset stage
-    dataset_stage = general_args.stage
-    if general_args.stage == "ppo":
-        dataset_stage = "sft"
-    elif general_args.stage == "dpo":
-        dataset_stage = "rm"
-
-    for dataset_attr in data_args.dataset_list:
-        if dataset_attr.stage and dataset_attr.stage != dataset_stage:
-            raise ValueError("Dataset {} is not supported for the stage {}"
-                             .format(dataset_attr.dataset_name, general_args.stage))
-
     model_args.model_max_length = data_args.max_source_length + data_args.max_target_length

     # Log on each process the small summary:
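The surviving assignment derives the tokenizer's context budget from the two data arguments; for example, with hypothetical argument values:

    max_source_length, max_target_length = 1024, 512  # hypothetical values
    model_max_length = max_source_length + max_target_length
    print(model_max_length)  # 1536: prompt and completion share one context window
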
@@ -1,9 +1,9 @@
 import torch
 from collections import defaultdict
-from peft import PeftModel
 from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
 from transformers import BatchEncoding, Trainer
 from trl import DPOTrainer
+from trl.trainer.utils import disable_dropout_in_model

 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.tuner.core.trainer import PeftModelMixin
@@ -18,9 +18,16 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
     def __init__(
         self,
         finetuning_args: "FinetuningArguments",
+        model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
+        disable_dropout: Optional[bool] = True,
         **kwargs
     ):
+        if disable_dropout:
+            disable_dropout_in_model(model)
+            if ref_model is not None:
+                disable_dropout_in_model(ref_model)
+
         self.finetuning_args = finetuning_args
         self.ref_model = ref_model
         self.use_dpo_data_collator = True  # hack to avoid warning
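The newly imported trl helper simply zeroes every dropout probability in the module tree, so the policy and the reference score sequences deterministically during DPO. A sketch equivalent to what `disable_dropout_in_model` does:

    import torch.nn as nn

    def disable_dropout_sketch(model: nn.Module) -> None:
        # Equivalent effect: every nn.Dropout becomes a no-op.
        for module in model.modules():
            if isinstance(module, nn.Dropout):
                module.p = 0.0
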
@@ -29,12 +36,16 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
         self.beta = finetuning_args.dpo_beta
         self._stored_metrics = defaultdict(lambda: defaultdict(list))

-        Trainer.__init__(self, **kwargs)
+        Trainer.__init__(self, model=model, **kwargs)
         if not hasattr(self, "accelerator"):
             raise AttributeError("Please update `transformers`.")

         if ref_model is not None:
-            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+            if self.is_deepspeed_enabled:
+                self.ref_model = self.accelerator._prepare_deepspeed(self.ref_model)
+                self.ref_model.eval()
+            else:
+                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

     def concatenated_forward(
         self,
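Why the branch: under DeepSpeed ZeRO the frozen reference model must be wrapped by the engine so its (possibly sharded) weights can be gathered at forward time, which is what the private `accelerator._prepare_deepspeed` hook used in the diff provides; in the plain case, `prepare_model(..., evaluation_mode=True)` places the model on the right device without registering it for optimization. Factored out as a sketch of the pattern above:

    def prepare_ref_model(accelerator, ref_model, is_deepspeed_enabled: bool):
        # Frozen reference model: wrap for the active backend, never train it.
        if is_deepspeed_enabled:
            ref_model = accelerator._prepare_deepspeed(ref_model)  # private hook, as in the diff
            ref_model.eval()
        else:
            ref_model = accelerator.prepare_model(ref_model, evaluation_mode=True)
        return ref_model
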
@@ -42,27 +53,12 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
         batch: Optional[Dict[str, torch.Tensor]] = None
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
         batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})  # avoid error
-        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
-
-        if not torch.is_grad_enabled():
-            unwrapped_model.gradient_checkpointing_disable()
-
-        if model is None and isinstance(unwrapped_model, PeftModel):  # peft model has no ref_model
-            with unwrapped_model.disable_adapter():
-                all_logits = self.model(
-                    input_ids=batch_copied["input_ids"],
-                    attention_mask=batch_copied["attention_mask"],
-                    return_dict=True
-                ).logits.to(torch.float32)
-        else:
-            all_logits = model(
-                input_ids=batch_copied["input_ids"],
-                attention_mask=batch_copied["attention_mask"],
-                return_dict=True
-            ).logits.to(torch.float32)
-
-        if not torch.is_grad_enabled():
-            unwrapped_model.gradient_checkpointing_enable()
+        all_logits = model(
+            input_ids=batch_copied["input_ids"],
+            attention_mask=batch_copied["attention_mask"],
+            return_dict=True
+        ).logits.to(torch.float32)

         all_logps = self._get_batch_logps(
             all_logits,
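With the PEFT adapter-disabling fallback gone, policy and reference take the same code path and differ only in which `model` is handed in. For reference, a simplified sketch of what `_get_batch_logps` (inherited from trl's DPOTrainer) computes from those float32 logits, assuming sum-reduction and the usual -100 ignore index:

    import torch

    def get_batch_logps_sketch(logits: torch.FloatTensor,
                               labels: torch.LongTensor,
                               ignore_index: int = -100) -> torch.FloatTensor:
        # Log-probability of each sequence: sum of label-token log-probs,
        # predicting token t from positions < t and skipping ignored labels.
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
        loss_mask = labels != ignore_index
        labels[labels == ignore_index] = 0  # any valid id; masked out below
        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
        ).squeeze(2)
        return (per_token_logps * loss_mask).sum(-1)
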
@@ -202,7 +202,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         queries: torch.Tensor,
         responses: torch.Tensor,
         model_inputs: dict,
-        return_logits: Optional[bool] = False
+        return_logits: Optional[bool] = False,
+        response_masks: Optional[torch.Tensor] = None
     ):
         r"""
         Calculates model outputs in multiple batches.
@@ -220,6 +221,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
             query_batch = queries[i * fbs : (i + 1) * fbs]
             response_batch = responses[i * fbs : (i + 1) * fbs]
+            if response_masks is not None:
+                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
             input_ids = input_kwargs["input_ids"]
             attention_mask = input_kwargs["attention_mask"]
@@ -239,8 +242,15 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                     start += attention_mask[j, :].nonzero()[0]
                 end = start + len(response_batch[j])

+                if response_masks is not None:
+                    response_masks_batch[j] = torch.cat(
+                        (torch.zeros_like(query_batch[j]), response_masks_batch[j])
+                    )[1:]
+
                 masks[j, :start] = 0
                 masks[j, end:] = 0
+                if response_masks is not None:
+                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

             if return_logits:
                 all_logits.append(logits)
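The `torch.cat(...)[1:]` in the new block left-shifts the response mask by one position so it lines up with the logits/labels offset used for `start` and `end`. A toy check of that alignment (token ids are made up):

    import torch

    query = torch.tensor([11, 12, 13])    # 3 prompt tokens
    response_mask = torch.tensor([1, 0])  # keep 1st response token, drop 2nd
    aligned = torch.cat((torch.zeros_like(query), response_mask))[1:]
    print(aligned)  # tensor([0, 0, 1, 0]); length = len(query) + len(response) - 1
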
@@ -44,7 +44,6 @@ def run_ppo(
     )

     if finetuning_args.ppo_score_norm:
-        require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
        ppo_config.use_score_scaling = True
        ppo_config.use_score_norm = True

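Dropping the `trl>=0.5.1.dev0` guard is consistent with the global pin moving to `trl>=0.7.1`, where score scaling and normalization ship in a release. The two attributes being set are plain `PPOConfig` fields; a minimal sketch, assuming trl 0.7.x field names (the model name is hypothetical):

    from trl import PPOConfig

    ppo_config = PPOConfig(
        model_name="llama-7b",     # hypothetical
        use_score_scaling=True,    # running scaling of reward signals
        use_score_norm=True,       # normalize the scaled scores
    )
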