[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
|
||||
@@ -31,10 +31,8 @@ if is_matplotlib_available():
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def smooth(scalars: List[float]) -> List[float]:
|
||||
r"""
|
||||
EMA implementation according to TensorBoard.
|
||||
"""
|
||||
def smooth(scalars: list[float]) -> list[float]:
|
||||
r"""EMA implementation according to TensorBoard."""
|
||||
if len(scalars) == 0:
|
||||
return []
|
||||
|
||||
@@ -48,10 +46,8 @@ def smooth(scalars: List[float]) -> List[float]:
|
||||
return smoothed
|
||||
|
||||
|
||||
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
|
||||
r"""
|
||||
Plots loss curves in LlamaBoard.
|
||||
"""
|
||||
def gen_loss_plot(trainer_log: list[dict[str, Any]]) -> "matplotlib.figure.Figure":
|
||||
r"""Plot loss curves in LlamaBoard."""
|
||||
plt.close("all")
|
||||
plt.switch_backend("agg")
|
||||
fig = plt.figure()
|
||||
@@ -70,10 +66,8 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
|
||||
return fig
|
||||
|
||||
|
||||
def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
|
||||
r"""
|
||||
Plots loss curves and saves the image.
|
||||
"""
|
||||
def plot_loss(save_dictionary: str, keys: list[str] = ["loss"]) -> None:
|
||||
r"""Plot loss curves and saves the image."""
|
||||
plt.switch_backend("agg")
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
Reference in New Issue
Block a user