remove PeftTrainer

Former-commit-id: cc0cff3e991f194732d278e627648e528118a719
This commit is contained in:
hiyouga
2023-09-10 22:23:23 +08:00
parent 332d7bbd56
commit a09a7b650d
17 changed files with 75 additions and 259 deletions

View File

@@ -4,27 +4,25 @@ import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from transformers import GenerationConfig, TrainerState, TrainerControl
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
from trl import PPOTrainer
from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.core.trainer import PeftTrainer
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.callbacks import LogCallback
from llmtuner.hparams import FinetuningArguments, GeneratingArguments
from llmtuner.hparams import GeneratingArguments
logger = get_logger(__name__)
class PPOPeftTrainer(PPOTrainer, PeftTrainer):
class CustomPPOTrainer(PPOTrainer, Trainer):
r"""
Inherits PPOTrainer.
"""
@@ -32,9 +30,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
def __init__(
self,
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: List["LogCallback"],
callbacks: List["TrainerCallback"],
compute_dtype: torch.dtype,
**kwargs
):
@@ -43,9 +40,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
self.args = training_args
self.finetuning_args = finetuning_args
self.generating_args = generating_args
self.log_callback = callbacks[0]
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
self.compute_dtype = compute_dtype
self.state = TrainerState()
self.control = TrainerControl()
@@ -147,7 +143,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
dataiter = iter(self.dataloader)
steps_trained = 0
self.log_callback.on_train_end(self.args, self.state, self.control)
self.log_callback.on_train_end(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)
@torch.no_grad()
def get_inputs(
@@ -296,3 +294,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
"""
if self.args.should_save:
self._save(output_dir)
self.save_callback.on_save(
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
)

View File

@@ -8,9 +8,10 @@ from transformers import DataCollatorWithPadding
from transformers.optimization import get_scheduler
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -61,11 +62,10 @@ def run_ppo(
)
# Initialize our Trainer
ppo_trainer = PPOPeftTrainer(
ppo_trainer = CustomPPOTrainer(
training_args=training_args,
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks,
callbacks=callbacks + [SavePeftModelCallback()],
compute_dtype=model_args.compute_dtype,
config=ppo_config,
model=model,