Deprecate reserved_label_len arg


Former-commit-id: 4b6568984c0be4b31e7aa91b7c0d52b7f7b12b0b
This commit is contained in:
hiyouga
2024-07-01 01:19:27 +08:00
parent b670fb57db
commit 67d2eb6b2a
13 changed files with 329 additions and 223 deletions

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union
from datasets import concatenate_datasets, interleave_datasets
@@ -30,6 +30,9 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@unique
class Role(str, Enum):
USER = "user"
@@ -39,13 +42,6 @@ class Role(str, Enum):
OBSERVATION = "observation"
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
max_target_len = int(max_len * (target_len / (source_len + target_len)))
max_target_len = max(max_target_len, reserved_label_len)
max_source_len = max_len - min(max_target_len, target_len)
return max_source_len, max_target_len
def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]],
data_args: "DataArguments",