remove PeftTrainer
Former-commit-id: cc0cff3e991f194732d278e627648e528118a719
This commit is contained in:
@@ -8,9 +8,10 @@ from transformers import DataCollatorWithPadding
|
||||
from transformers.optimization import get_scheduler
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.callbacks import SavePeftModelCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
|
||||
from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
@@ -61,11 +62,10 @@ def run_ppo(
|
||||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
ppo_trainer = PPOPeftTrainer(
|
||||
ppo_trainer = CustomPPOTrainer(
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
generating_args=generating_args,
|
||||
callbacks=callbacks,
|
||||
callbacks=callbacks + [SavePeftModelCallback()],
|
||||
compute_dtype=model_args.compute_dtype,
|
||||
config=ppo_config,
|
||||
model=model,
|
||||
|
||||
Reference in New Issue
Block a user