allow non-packing pretraining

Former-commit-id: 3fee5cc5a3db9ce874ad90f2500ec092d904bd4e
This commit is contained in:
hiyouga
2024-03-09 22:21:46 +08:00
parent c631799f5d
commit 4881f4e631
22 changed files with 64 additions and 67 deletions

View File

@@ -1,7 +1,7 @@
import json
import math
import os
from typing import List, Optional
from typing import List
from transformers.trainer import TRAINER_STATE_NAME
@@ -30,7 +30,7 @@ def smooth(scalars: List[float]) -> List[float]:
return smoothed
def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
data = json.load(f)