support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
@@ -19,7 +19,7 @@ from typing import Any, Dict, List
|
||||
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
|
||||
from .logging import get_logger
|
||||
from . import logging
|
||||
from .packages import is_matplotlib_available
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ if is_matplotlib_available():
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def smooth(scalars: List[float]) -> List[float]:
|
||||
@@ -86,7 +86,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
|
||||
metrics.append(data["log_history"][i][key])
|
||||
|
||||
if len(metrics) == 0:
|
||||
logger.warning(f"No metric {key} to plot.")
|
||||
logger.warning_rank0(f"No metric {key} to plot.")
|
||||
continue
|
||||
|
||||
plt.figure()
|
||||
|
||||
Reference in New Issue
Block a user