alter rewards data type
Former-commit-id: 3eb7eb2d37525da50fe401ab7c59532e6e1ef984
This commit is contained in:
@@ -5,7 +5,6 @@ import torch
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.trainer import TRAINER_STATE_NAME
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.generation.utils import LogitsProcessorList
|
||||
@@ -143,7 +142,7 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
|
||||
|
||||
|
||||
def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
|
||||
def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
|
||||
"""
|
||||
EMA implementation according to TensorBoard.
|
||||
"""
|
||||
@@ -156,9 +155,10 @@ def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
|
||||
return smoothed
|
||||
|
||||
|
||||
def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]] = ["loss"]) -> None:
|
||||
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
|
||||
import matplotlib.pyplot as plt
|
||||
data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r"))
|
||||
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
for key in keys:
|
||||
steps, metrics = [], []
|
||||
@@ -174,9 +174,9 @@ def plot_loss(training_args: Seq2SeqTrainingArguments, keys: Optional[List[str]]
|
||||
plt.figure()
|
||||
plt.plot(steps, metrics, alpha=0.4, label="original")
|
||||
plt.plot(steps, smooth(metrics), label="smoothed")
|
||||
plt.title("training {} of {}".format(key, training_args.output_dir))
|
||||
plt.title("training {} of {}".format(key, save_dictionary))
|
||||
plt.xlabel("step")
|
||||
plt.ylabel(key)
|
||||
plt.legend()
|
||||
plt.savefig(os.path.join(training_args.output_dir, "training_{}.png".format(key)), format="png", dpi=100)
|
||||
print("Figure saved:", os.path.join(training_args.output_dir, "training_{}.png".format(key)))
|
||||
plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
|
||||
print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
|
||||
|
||||
Reference in New Issue
Block a user