support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
@@ -25,8 +25,8 @@ import torch
|
||||
from transformers import Seq2SeqTrainer
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.packages import is_transformers_version_equal_to_4_46
|
||||
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
|
||||
from ...hparams import FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
@@ -142,7 +142,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
return
|
||||
|
||||
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
|
||||
logger.info(f"Saving prediction results to {output_prediction_file}")
|
||||
logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
|
||||
|
||||
labels = np.where(
|
||||
predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id
|
||||
|
||||
Reference in New Issue
Block a user