[trainer] fix pt loss (#7748)

* fix pt loss

* robust

* fix

* test
This commit is contained in:
hoshi-hiyouga
2025-04-17 03:15:35 +08:00
committed by GitHub
parent 86ebb219d6
commit 39169986ef
10 changed files with 34 additions and 34 deletions

View File

@@ -21,7 +21,7 @@ import re
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Optional, TypedDict, Union
import numpy as np
import torch
@@ -86,7 +86,7 @@ if TYPE_CHECKING:
pass
def _concatenate_list(input_list):
def _concatenate_list(input_list: list[Any]) -> Union[list[Any], "NDArray", "torch.Tensor"]:
r"""Concatenate a list of lists, numpy arrays or torch tensors.
Returns: