mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 08:33:38 +00:00
remove PeftTrainer
Former-commit-id: cc0cff3e991f194732d278e627648e528118a719
This commit is contained in:
@@ -4,10 +4,10 @@ import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from transformers import Seq2SeqTrainer
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.trainer import PredictionOutput
|
||||
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqPeftTrainer(PeftTrainer):
|
||||
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
r"""
|
||||
Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user