remove PeftTrainer
Former-commit-id: cc0cff3e991f194732d278e627648e528118a719
@@ -11,7 +11,6 @@ from peft import (
 from peft.utils import CONFIG_NAME, WEIGHTS_NAME
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.save_and_load import load_trainable_params
 from llmtuner.tuner.core.utils import find_all_linear_modules
 
 if TYPE_CHECKING:
@@ -53,9 +52,6 @@ def init_adapter(
             else:
                 param.data = param.data.to(torch.float32)
 
-    if model_args.checkpoint_dir is not None:
-        assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
-
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
         latest_checkpoint = None
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-check_min_version("4.29.1")
+check_min_version("4.30.0")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft==0.4.0", "To fix: pip install peft==0.4.0")
@@ -78,7 +78,7 @@ def load_model_and_tokenizer(
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
 
-    if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
+    if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
         model_to_load = model_args.checkpoint_dir[0]
     else:
         model_to_load = model_args.model_name_or_path
@@ -197,6 +197,7 @@ def load_model_and_tokenizer(
     # Prepare model with valuehead for RLHF
     if stage == "rm" or stage == "ppo":
         model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+        model._keys_to_ignore_on_save = None
         reset_logging()
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
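For context on the value-head wrapper used in the hunk above, here is a minimal sketch (not part of this commit; the model name is a placeholder) of how trl attaches the scalar value head whose weights this repo stores separately under VALUE_HEAD_FILE_NAME:

    # Sketch only: trl wraps a causal LM and adds a scalar value head (`v_head`).
    # The v_head weights are not part of the base checkpoint, which is why the
    # "rm" stage above re-loads them from the last checkpoint directory.
    from trl import AutoModelForCausalLMWithValueHead

    model = AutoModelForCausalLMWithValueHead.from_pretrained("sshleifer/tiny-gpt2")  # placeholder model
    print(sum(p.numel() for p in model.v_head.parameters()))  # extra value-head parameters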
@@ -1,118 +0,0 @@
-import os
-import torch
-from typing import TYPE_CHECKING, Dict, Optional
-
-from transformers import Seq2SeqTrainer
-from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
-from transformers.modeling_utils import PreTrainedModel, unwrap_model
-from peft import PeftModel
-from trl import PreTrainedModelWrapper
-
-from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
-from llmtuner.extras.logging import get_logger
-from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
-
-if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
-    from llmtuner.hparams import FinetuningArguments
-
-
-logger = get_logger(__name__)
-
-
-class PeftModelMixin:
-    r"""
-    Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
-    """
-
-    def __init__(self) -> None: # for type checking
-        self.model: PreTrainedModel = None
-        self.tokenizer: "PreTrainedTokenizer" = None
-        self.args: "Seq2SeqTrainingArguments" = None
-        self.finetuning_args: "FinetuningArguments" = None
-        self.state: "TrainerState" = None
-        raise AssertionError("Mixin should not be initialized.")
-
-    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
-        r"""
-        Saves trainable parameters as model checkpoint.
-
-        This function will only be executed at the process zero.
-
-        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
-        """
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
-        os.makedirs(output_dir, exist_ok=True)
-        logger.info(f"Saving model checkpoint to {output_dir}")
-        model = self.model
-        model_unwrapped = unwrap_model(model)
-
-        if isinstance(model_unwrapped, PreTrainedModelWrapper):
-            # Custom state dict: https://github.com/lvwerra/trl/blob/v0.7.1/trl/models/modeling_value_head.py#L200
-            model_state_dict = state_dict or model.state_dict()
-            v_head_state_dict = {
-                name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
-                for name in model_state_dict.keys() if name.startswith("v_head.")
-            }
-            torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
-            model = model_unwrapped.pretrained_model
-            model_unwrapped = unwrap_model(model)
-
-        state_dict = state_dict or get_state_dict(model)
-        if not isinstance(model, (PeftModel, PreTrainedModel)):
-            if isinstance(model_unwrapped, (PeftModel, PreTrainedModel)):
-                model_unwrapped.config.use_cache = True
-                model_unwrapped.save_pretrained(
-                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
-                )
-                model_unwrapped.config.use_cache = False
-            else:
-                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
-                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
-        else:
-            model.config.use_cache = True
-            model.save_pretrained(
-                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
-            )
-            model.config.use_cache = False
-
-        if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
-            try:
-                self.tokenizer.save_pretrained(output_dir)
-            except:
-                logger.warning("Cannot save tokenizer, copy the files manually.")
-
-        with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
-            f.write(self.args.to_json_string() + "\n")
-
-        self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
-
-    def _load_best_model(self):
-        r"""
-        Loads trainable parameters from model checkpoint.
-
-        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
-        """
-        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
-        model = unwrap_model(self.model)
-
-        if isinstance(model, PreTrainedModelWrapper):
-            model.v_head.load_state_dict(torch.load(
-                os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
-            ))
-            model = model.pretrained_model
-
-        if isinstance(model, PeftModel):
-            model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
-        else: # freeze/full-tuning
-            load_trainable_params(model, self.state.best_model_checkpoint)
-
-
-class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
-    r"""
-    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
-    """
-
-    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
-        Seq2SeqTrainer.__init__(self, **kwargs)
-        self.finetuning_args = finetuning_args
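With PeftModelMixin and PeftTrainer deleted above, adapter checkpoints fall back on the stock libraries. A minimal sketch of that assumption (recent transformers/peft releases, tiny placeholder model): saving a PeftModel keeps only the adapter weights, which is the part of the deleted _save logic that no longer needs custom code.

    # Sketch, not from this commit: PeftModel.save_pretrained writes only
    # adapter_config.json plus the adapter weights, so a plain Trainer checkpoint
    # of a LoRA model stays small without the custom _save above.
    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, TaskType, get_peft_model

    base = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")  # placeholder model
    lora_model = get_peft_model(base, LoraConfig(task_type=TaskType.CAUSAL_LM, r=8))
    lora_model.save_pretrained("lora_out")  # adapter-only checkpoint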
@@ -6,18 +6,16 @@ from trl import DPOTrainer
 from trl.trainer.utils import disable_dropout_in_model
 
 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.tuner.core.trainer import PeftModelMixin
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
-    from llmtuner.hparams import FinetuningArguments
 
 
-class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
+class CustomDPOTrainer(DPOTrainer):
 
     def __init__(
         self,
-        finetuning_args: "FinetuningArguments",
+        beta: float,
         model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         disable_dropout: Optional[bool] = True,
@@ -28,12 +26,11 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
         if ref_model is not None:
             disable_dropout_in_model(ref_model)
 
-        self.finetuning_args = finetuning_args
         self.ref_model = ref_model
         self.use_dpo_data_collator = True # hack to avoid warning
         self.label_pad_token_id = IGNORE_INDEX
         self.padding_value = 0
-        self.beta = finetuning_args.dpo_beta
+        self.beta = beta
         self._stored_metrics = defaultdict(lambda: defaultdict(list))
 
         Trainer.__init__(self, model=model, **kwargs)
@@ -10,7 +10,7 @@ from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
-from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
+from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -37,10 +37,10 @@ def run_dpo(
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
-    trainer = DPOPeftTrainer(
-        finetuning_args=finetuning_args,
-        ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
+    trainer = CustomDPOTrainer(
+        beta=finetuning_args.dpo_beta,
+        model=model,
+        ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
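A note on the ref_model argument above: when the policy is already a PeftModel, no deep copy is passed, on the assumption that trl's DPOTrainer can recover the reference policy from the frozen base weights by disabling the adapter. A rough sketch of that idea (the helper name is hypothetical):

    import torch

    def reference_logits(policy_model, batch):
        # Sketch only: with a LoRA policy, disabling the adapter exposes the frozen
        # base model, which serves as the DPO reference without a second model copy.
        with torch.no_grad(), policy_model.disable_adapter():
            return policy_model(**batch).logits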
@@ -4,27 +4,25 @@ import torch
 from tqdm import tqdm
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
-from transformers import GenerationConfig, TrainerState, TrainerControl
+from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 
 from trl import PPOTrainer
 from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
 
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.tuner.core.trainer import PeftTrainer
 from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
 
 if TYPE_CHECKING:
-    from transformers import Seq2SeqTrainingArguments
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
     from trl import AutoModelForCausalLMWithValueHead
-    from llmtuner.extras.callbacks import LogCallback
-    from llmtuner.hparams import FinetuningArguments, GeneratingArguments
+    from llmtuner.hparams import GeneratingArguments
 
 
 logger = get_logger(__name__)
 
 
-class PPOPeftTrainer(PPOTrainer, PeftTrainer):
+class CustomPPOTrainer(PPOTrainer, Trainer):
     r"""
     Inherits PPOTrainer.
     """
@@ -32,9 +30,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     def __init__(
         self,
         training_args: "Seq2SeqTrainingArguments",
-        finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
-        callbacks: List["LogCallback"],
+        callbacks: List["TrainerCallback"],
         compute_dtype: torch.dtype,
         **kwargs
     ):
@@ -43,9 +40,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
 
         self.args = training_args
-        self.finetuning_args = finetuning_args
         self.generating_args = generating_args
-        self.log_callback = callbacks[0]
+        self.log_callback, self.save_callback = callbacks[0], callbacks[1]
         self.compute_dtype = compute_dtype
         self.state = TrainerState()
         self.control = TrainerControl()
@@ -147,7 +143,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                 dataiter = iter(self.dataloader)
                 steps_trained = 0
 
-        self.log_callback.on_train_end(self.args, self.state, self.control)
+        self.log_callback.on_train_end(
+            self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+        )
 
     @torch.no_grad()
     def get_inputs(
@@ -296,3 +294,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         """
         if self.args.should_save:
             self._save(output_dir)
+            self.save_callback.on_save(
+                self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
+            )
@@ -8,9 +8,10 @@ from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler
 
 from llmtuner.dsets import get_dataset, preprocess_dataset
+from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
-from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
+from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -61,11 +62,10 @@ def run_ppo(
     )
 
     # Initialize our Trainer
-    ppo_trainer = PPOPeftTrainer(
+    ppo_trainer = CustomPPOTrainer(
         training_args=training_args,
-        finetuning_args=finetuning_args,
         generating_args=generating_args,
-        callbacks=callbacks,
+        callbacks=callbacks + [SavePeftModelCallback()],
         compute_dtype=model_args.compute_dtype,
         config=ppo_config,
         model=model,
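SavePeftModelCallback (added to the callback lists here and in run_rm below) is imported from llmtuner.extras.callbacks, but its implementation is not part of this diff. A hedged sketch of what such a callback might do, reusing the on_save event that CustomPPOTrainer invokes above; the checkpoint naming and the v_head handling are assumptions:

    import os
    import torch
    from transformers import TrainerCallback
    from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

    from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME


    class SavePeftModelCallback(TrainerCallback):
        # Sketch only: persist the value head next to each checkpoint so the adapter
        # saved by the stock Trainer remains usable for RM evaluation and PPO resume.
        def on_save(self, args, state, control, **kwargs):
            if args.should_save:
                model = kwargs.get("model")
                output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
                if model is not None and getattr(model, "v_head", None) is not None:
                    os.makedirs(output_dir, exist_ok=True)
                    torch.save(model.v_head.state_dict(), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))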
@@ -2,12 +2,11 @@
 
 import math
 from typing import TYPE_CHECKING, Optional, List
-from transformers import DataCollatorForLanguageModeling
+from transformers import DataCollatorForLanguageModeling, Trainer
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
-from llmtuner.tuner.core.trainer import PeftTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -27,8 +26,7 @@ def run_pt(
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Initialize our Trainer
-    trainer = PeftTrainer(
-        finetuning_args=finetuning_args,
+    trainer = Trainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
@@ -2,9 +2,9 @@ import os
 import json
 import torch
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from transformers import Trainer
 
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.trainer import PeftTrainer
 
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-class PairwisePeftTrainer(PeftTrainer):
+class PairwiseTrainer(Trainer):
     r"""
     Inherits PeftTrainer to compute pairwise loss.
     """
@@ -5,11 +5,12 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import Seq2SeqTrainingArguments
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.extras.callbacks import SavePeftModelCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.rm.metric import compute_accuracy
 from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
-from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
+from llmtuner.tuner.rm.trainer import PairwiseTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -33,13 +34,12 @@ def run_rm(
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
-    trainer = PairwisePeftTrainer(
-        finetuning_args=finetuning_args,
+    trainer = PairwiseTrainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        callbacks=callbacks,
+        callbacks=callbacks + [SavePeftModelCallback()],
         compute_metrics=compute_accuracy,
         **split_dataset(dataset, data_args, training_args)
     )
@@ -4,10 +4,10 @@ import torch
 import numpy as np
 import torch.nn as nn
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from transformers import Seq2SeqTrainer
 
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.logging import get_logger
-from llmtuner.tuner.core.trainer import PeftTrainer
 
 if TYPE_CHECKING:
     from transformers.trainer import PredictionOutput
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-class Seq2SeqPeftTrainer(PeftTrainer):
+class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     r"""
     Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
     """
@@ -9,7 +9,7 @@ from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.sft.metric import ComputeMetrics
-from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
+from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
 
 if TYPE_CHECKING:
     from transformers import TrainerCallback
@@ -45,8 +45,7 @@ def run_sft(
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
 
     # Initialize our Trainer
-    trainer = Seq2SeqPeftTrainer(
-        finetuning_args=finetuning_args,
+    trainer = CustomSeq2SeqTrainer(
         model=model,
         args=training_args,
         tokenizer=tokenizer,