Deprecate reserved_label_len arg Former-commit-id: 4b6568984c0be4b31e7aa91b7c0d52b7f7b12b0b
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets
|
||||
|
||||
@@ -30,6 +30,9 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||
|
||||
|
||||
@unique
|
||||
class Role(str, Enum):
|
||||
USER = "user"
|
||||
@@ -39,13 +42,6 @@ class Role(str, Enum):
|
||||
OBSERVATION = "observation"
|
||||
|
||||
|
||||
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
|
||||
max_target_len = int(max_len * (target_len / (source_len + target_len)))
|
||||
max_target_len = max(max_target_len, reserved_label_len)
|
||||
max_source_len = max_len - min(max_target_len, target_len)
|
||||
return max_source_len, max_target_len
|
||||
|
||||
|
||||
def merge_dataset(
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]],
|
||||
data_args: "DataArguments",
|
||||
|
||||
Reference in New Issue
Block a user