support DPO training (2305.18290)
Former-commit-id: 6d98de148e4af63a7028dfaeb6cf86eb56a4488f
This commit is contained in:
@@ -39,7 +39,7 @@ def init_adapter(
|
||||
if finetuning_args.finetuning_type == "none" and is_trainable:
|
||||
raise ValueError("You cannot use finetuning_type=none while training.")
|
||||
|
||||
if finetuning_args.finetuning_type == "full":
|
||||
if finetuning_args.finetuning_type == "full" and is_trainable:
|
||||
logger.info("Fine-tuning method: Full")
|
||||
model = model.float()
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ check_min_version("4.29.1")
|
||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
|
||||
require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
|
||||
require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
@@ -52,9 +52,6 @@ def load_model_and_tokenizer(
|
||||
logger.warning("Checkpoint is not found at evaluation, load the original model.")
|
||||
finetuning_args = FinetuningArguments(finetuning_type="none")
|
||||
|
||||
assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
|
||||
"RM and PPO training can only be performed with the LoRA method."
|
||||
|
||||
config_kwargs = {
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
@@ -132,8 +129,6 @@ def load_model_and_tokenizer(
|
||||
})
|
||||
|
||||
if stage == "ppo": # load reward model
|
||||
assert is_trainable, "PPO stage cannot be performed at evaluation."
|
||||
assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
|
||||
logger.info("Load reward model from {}".format(model_args.reward_model))
|
||||
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
|
||||
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
|
||||
|
||||
@@ -19,7 +19,7 @@ from llmtuner.hparams import (
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
|
||||
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
@@ -32,26 +32,53 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
|
||||
|
||||
def parse_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
))
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def parse_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
))
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
]:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
@@ -68,7 +95,7 @@ def get_train_args(
|
||||
data_args.init_for_training()
|
||||
|
||||
if general_args.stage != "sft" and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")
|
||||
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
||||
|
||||
if training_args.do_train and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||
@@ -76,6 +103,15 @@ def get_train_args(
|
||||
if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
||||
|
||||
if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("RM and PPO training can only be performed with the LoRA method.")
|
||||
|
||||
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
|
||||
raise ValueError("PPO and DPO stage can only be performed at training.")
|
||||
|
||||
if general_args.stage == "ppo" and model_args.reward_model is None:
|
||||
raise ValueError("Reward model is necessary for PPO training.")
|
||||
|
||||
if training_args.max_steps == -1 and data_args.streaming:
|
||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||
|
||||
@@ -133,12 +169,17 @@ def get_train_args(
|
||||
# Set seed before initializing model.
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args, general_args
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args, general_args
|
||||
|
||||
|
||||
def get_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
|
||||
@@ -13,26 +13,25 @@ from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PeftTrainer(Seq2SeqTrainer):
|
||||
class PeftModelMixin:
|
||||
r"""
|
||||
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
||||
Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
|
||||
"""
|
||||
|
||||
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
self._remove_log()
|
||||
|
||||
def _remove_log(self):
|
||||
if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
|
||||
logger.warning("Previous log file in this folder will be deleted.")
|
||||
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
|
||||
def __init__(self) -> None: # for type checking
|
||||
self.model: PreTrainedModel = None
|
||||
self.tokenizer: "PreTrainedTokenizer" = None
|
||||
self.args: "Seq2SeqTrainingArguments" = None
|
||||
self.finetuning_args: "FinetuningArguments" = None
|
||||
self.state: "TrainerState" = None
|
||||
raise AssertionError("Mixin should not be initialized.")
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
|
||||
r"""
|
||||
@@ -96,3 +95,13 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
|
||||
else: # freeze/full-tuning
|
||||
load_trainable_params(model, self.state.best_model_checkpoint)
|
||||
|
||||
|
||||
class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
|
||||
r"""
|
||||
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
|
||||
Seq2SeqTrainer.__init__(self, **kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
|
||||
1
src/llmtuner/tuner/dpo/__init__.py
Normal file
1
src/llmtuner/tuner/dpo/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from llmtuner.tuner.dpo.workflow import run_dpo
|
||||
51
src/llmtuner/tuner/dpo/collator.py
Normal file
51
src/llmtuner/tuner/dpo/collator.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Sequence, Tuple
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
@dataclass
|
||||
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
|
||||
padded_labels = []
|
||||
for feature, (prompt_len, answer_len) in zip(batch, positions):
|
||||
if self.tokenizer.padding_side == "left":
|
||||
start, end = feature.size(0) - answer_len, feature.size(0)
|
||||
else:
|
||||
start, end = prompt_len, answer_len
|
||||
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
|
||||
padded_tensor[start:end] = feature[start:end]
|
||||
padded_labels.append(padded_tensor)
|
||||
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
concatenated_features = []
|
||||
label_positions = []
|
||||
for key in ("chosen_ids", "rejected_ids"):
|
||||
for feature in features:
|
||||
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
|
||||
concatenated_features.append({
|
||||
"input_ids": feature["prompt_ids"] + feature[key],
|
||||
"attention_mask": [1] * (prompt_len + answer_len)
|
||||
})
|
||||
label_positions.append((prompt_len, answer_len))
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
concatenated_features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors,
|
||||
)
|
||||
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
||||
return batch
|
||||
75
src/llmtuner/tuner/dpo/trainer.py
Normal file
75
src/llmtuner/tuner/dpo/trainer.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import torch
|
||||
from collections import defaultdict
|
||||
from peft import PeftModel
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||
from transformers import Trainer
|
||||
from trl import DPOTrainer
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.tuner.core.trainer import PeftModelMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
**kwargs
|
||||
):
|
||||
self.finetuning_args = finetuning_args
|
||||
self.generating_args = generating_args
|
||||
self.ref_model = ref_model
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.label_pad_token_id = IGNORE_INDEX
|
||||
self.padding_value = 0
|
||||
self.beta = finetuning_args.dpo_beta
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
Trainer.__init__(self, **kwargs)
|
||||
if ref_model is not None:
|
||||
if hasattr(self, "accelerator"):
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
else:
|
||||
raise AttributeError("Please update `transformers`.")
|
||||
|
||||
def concatenated_forward(
|
||||
self,
|
||||
model: Optional[torch.nn.Module] = None,
|
||||
batch: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
|
||||
if not torch.is_grad_enabled():
|
||||
unwrapped_model.gradient_checkpointing_disable()
|
||||
|
||||
if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
|
||||
with unwrapped_model.disable_adapter():
|
||||
all_logits: torch.Tensor = self.model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
return_dict=True
|
||||
).logits.to(torch.float32)
|
||||
else:
|
||||
all_logits: torch.Tensor = model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
return_dict=True
|
||||
).logits.to(torch.float32)
|
||||
|
||||
if not torch.is_grad_enabled():
|
||||
unwrapped_model.gradient_checkpointing_enable()
|
||||
|
||||
all_logps = self._get_batch_logps(
|
||||
all_logits,
|
||||
batch["labels"],
|
||||
average_log_prob=False
|
||||
)
|
||||
batch_size = batch["input_ids"].size(0) // 2
|
||||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
|
||||
59
src/llmtuner/tuner/dpo/workflow.py
Normal file
59
src/llmtuner/tuner/dpo/workflow.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
|
||||
|
||||
from copy import deepcopy
|
||||
from peft import PeftModel
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
|
||||
from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
def run_dpo(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
data_collator = DPODataCollatorWithPadding(
|
||||
tokenizer=tokenizer,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = DPOPeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
generating_args=generating_args,
|
||||
ref_model=ref_model,
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**split_dataset(dataset, data_args, training_args)
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
trainer.save_model()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
@@ -10,7 +10,7 @@ from trl import PPOTrainer
|
||||
from trl.core import LengthSampler
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor, get_stopping_criteria
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
|
||||
|
||||
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -33,16 +33,17 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
self,
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: List["LogCallback"],
|
||||
**kwargs
|
||||
):
|
||||
PPOTrainer.__init__(self, **kwargs)
|
||||
self.args = training_args
|
||||
self.finetuning_args = finetuning_args
|
||||
self.generating_args = generating_args
|
||||
self.log_callback = callbacks[0]
|
||||
self.state = TrainerState()
|
||||
self.control = TrainerControl()
|
||||
self._remove_log()
|
||||
|
||||
def ppo_train(self, max_target_length: int) -> None:
|
||||
r"""
|
||||
@@ -72,14 +73,10 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
gen_kwargs = {
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": self.tokenizer.pad_token_id,
|
||||
"eos_token_id": self.tokenizer.eos_token_id,
|
||||
"logits_processor": get_logits_processor()
|
||||
}
|
||||
gen_kwargs = self.generating_args.to_dict()
|
||||
gen_kwargs["logits_processor"] = get_logits_processor()
|
||||
gen_kwargs["stopping_criteria"] = get_stopping_criteria(self.tokenizer.additional_special_tokens_ids)
|
||||
|
||||
length_sampler = LengthSampler(max_target_length // 2, max_target_length)
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
# Inspired by:
|
||||
# https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
|
||||
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
from trl import PPOConfig
|
||||
from torch.optim import AdamW
|
||||
from typing import Optional, List
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from transformers.optimization import get_scheduler
|
||||
|
||||
@@ -16,7 +14,7 @@ from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
def run_ppo(
|
||||
@@ -24,6 +22,7 @@ def run_ppo(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
@@ -42,8 +41,9 @@ def run_ppo(
|
||||
)
|
||||
|
||||
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
|
||||
total_train_batch_size = \
|
||||
total_train_batch_size = (
|
||||
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
|
||||
)
|
||||
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
|
||||
lr_scheduler = get_scheduler(
|
||||
training_args.lr_scheduler_type,
|
||||
@@ -56,6 +56,7 @@ def run_ppo(
|
||||
ppo_trainer = PPOPeftTrainer(
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
generating_args=generating_args,
|
||||
callbacks=callbacks,
|
||||
config=ppo_config,
|
||||
model=model,
|
||||
@@ -67,8 +68,10 @@ def run_ppo(
|
||||
lr_scheduler=lr_scheduler
|
||||
)
|
||||
|
||||
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
|
||||
ppo_trainer.save_model()
|
||||
ppo_trainer.save_state() # must be after save_model
|
||||
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
|
||||
ppo_trainer.save_model()
|
||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
||||
|
||||
@@ -2,10 +2,9 @@
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
@@ -25,10 +24,7 @@ def run_pt(
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = PeftTrainer(
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Sequence
|
||||
from transformers import DataCollatorWithPadding
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
@@ -16,7 +18,10 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
features = [
|
||||
{"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
|
||||
for key in ("accept_ids", "reject_ids") for feature in features
|
||||
{
|
||||
"input_ids": feature["prompt_ids"] + feature[key],
|
||||
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
|
||||
}
|
||||
for key in ("chosen_ids", "rejected_ids") for feature in features
|
||||
]
|
||||
return super().__call__(features)
|
||||
|
||||
@@ -79,7 +79,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
|
||||
|
||||
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
|
||||
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
|
||||
return padded_tensor.contiguous()
|
||||
return padded_tensor.contiguous() # in contiguous memory
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
|
||||
@@ -5,7 +5,7 @@ from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.misc import get_logits_processor, get_stopping_criteria
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.sft.metric import ComputeMetrics
|
||||
@@ -13,7 +13,7 @@ from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
def run_sft(
|
||||
@@ -21,6 +21,7 @@ def run_sft(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
@@ -50,13 +51,9 @@ def run_sft(
|
||||
)
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
gen_kwargs = {
|
||||
"do_sample": True,
|
||||
"top_p": 0.7,
|
||||
"max_new_tokens": data_args.max_target_length + 1,
|
||||
"temperature": 0.95,
|
||||
"logits_processor": get_logits_processor()
|
||||
}
|
||||
gen_kwargs = generating_args.to_dict()
|
||||
gen_kwargs["logits_processor"] = get_logits_processor()
|
||||
gen_kwargs["stopping_criteria"] = get_stopping_criteria(tokenizer.additional_special_tokens_ids)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
|
||||
@@ -1,35 +1,47 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
|
||||
from llmtuner.tuner.pt import run_pt
|
||||
from llmtuner.tuner.sft import run_sft
|
||||
from llmtuner.tuner.rm import run_rm
|
||||
from llmtuner.tuner.ppo import run_ppo
|
||||
from llmtuner.tuner.dpo import run_dpo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
|
||||
model_args, data_args, training_args, finetuning_args, general_args = get_train_args(args)
|
||||
callbacks = [LogCallback()] if callbacks is None else callbacks
|
||||
model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
|
||||
callbacks = [LogCallback()] if callbacks is None else callbacks + [LogCallback()]
|
||||
|
||||
if general_args.stage == "pt":
|
||||
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif general_args.stage == "sft":
|
||||
run_sft(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif general_args.stage == "rm":
|
||||
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif general_args.stage == "ppo":
|
||||
run_ppo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif general_args.stage == "dpo":
|
||||
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
else:
|
||||
raise ValueError("Unknown task.")
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
|
||||
model_args, _, training_args, finetuning_args, _ = get_train_args(args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
try:
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
except:
|
||||
logger.warning("Cannot save tokenizer, please copy the files manually.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user