remove PeftTrainer

Former-commit-id: cc0cff3e991f194732d278e627648e528118a719
This commit is contained in:
hiyouga
2023-09-10 22:23:23 +08:00
parent 332d7bbd56
commit a09a7b650d
17 changed files with 75 additions and 259 deletions

View File

@@ -5,7 +5,9 @@ from typing import TYPE_CHECKING
from datetime import timedelta
from transformers import TrainerCallback
from transformers.trainer_utils import has_length
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
from transformers.training_args import TrainingArguments
from llmtuner.extras.constants import LOG_FILE_NAME
from llmtuner.extras.logging import get_logger
@@ -17,6 +19,24 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
class SavePeftModelCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
"""
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
return control
def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
r"""
Event called at the end of training.
"""
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
return control
class LogCallback(TrainerCallback):
def __init__(self, runner=None):