support rank0 logger

Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
hiyouga
2024-11-02 18:31:04 +08:00
parent ceb701c2d4
commit 093eda2ad6
42 changed files with 316 additions and 252 deletions

View File

@@ -19,7 +19,7 @@ from typing import Any, Dict, List
from transformers.trainer import TRAINER_STATE_NAME
from .logging import get_logger
from . import logging
from .packages import is_matplotlib_available
@@ -28,7 +28,7 @@ if is_matplotlib_available():
import matplotlib.pyplot as plt
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def smooth(scalars: List[float]) -> List[float]:
@@ -86,7 +86,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
metrics.append(data["log_history"][i][key])
if len(metrics) == 0:
logger.warning(f"No metric {key} to plot.")
logger.warning_rank0(f"No metric {key} to plot.")
continue
plt.figure()