support rank0 logger

Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
Author: hiyouga
Date: 2024-11-02 18:31:04 +08:00
parent ceb701c2d4
commit 093eda2ad6
42 changed files with 316 additions and 252 deletions
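
The hunks below only touch the call sites; the extras/logging.py module that actually provides the rank0-aware logger is part of the same commit but is not shown in this excerpt. For orientation, here is a minimal sketch of what such a logger could look like, assuming the worker rank is exposed through the RANK or LOCAL_RANK environment variables; the class name Rank0Logger is illustrative and not taken from the repository.

# Minimal sketch of a rank0-aware logger (illustrative, not the repository's implementation).
import logging
import os


class Rank0Logger:
    """Wraps a standard logger and adds *_rank0 methods that emit only on the main process."""

    def __init__(self, name: str) -> None:
        self._logger = logging.getLogger(name)

    def _is_rank0(self) -> bool:
        # Assumes a torchrun-style launcher that exports RANK (or LOCAL_RANK) for each worker.
        return int(os.getenv("RANK", os.getenv("LOCAL_RANK", "0"))) == 0

    def warning(self, msg: str, *args) -> None:
        # Plain warning: emitted by every process, as before this commit.
        self._logger.warning(msg, *args)

    def warning_rank0(self, msg: str, *args) -> None:
        # Emitted by rank 0 only, so multi-GPU runs do not repeat the same message per worker.
        if self._is_rank0():
            self._logger.warning(msg, *args)


def get_logger(name: str) -> Rank0Logger:
    # Mirrors the `logger = logging.get_logger(__name__)` calls introduced in the hunks below.
    return Rank0Logger(name)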

View File

@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import infer_seqlen
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 def _encode_feedback_example(
@@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue
 
         input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
@@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
     desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
     undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
     if desirable_num == 0 or undesirable_num == 0:
-        logger.warning("Your dataset only has one preference type.")
+        logger.warning_rank0("Your dataset only has one preference type.")
 
     return model_inputs

View File

@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import infer_seqlen
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 def _encode_pairwise_example(
@@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue
 
         chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(

View File

@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import greedy_knapsack, infer_seqlen
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 def _encode_supervised_example(
@@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue
 
         input_ids, labels = _encode_supervised_example(
@@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
     length2indexes = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue
 
         input_ids, labels = _encode_supervised_example(
@@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
         )
         length = len(input_ids)
         if length > data_args.cutoff_len:
-            logger.warning(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
+            logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
         else:
             lengths.append(length)
             length2indexes[length].append(valid_num)

View File

@@ -15,7 +15,7 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
-from ...extras.logging import get_logger
+from ...extras import logging
 from ..data_utils import Role
 from .processor_utils import infer_seqlen
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 def _encode_unsupervised_example(
@@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue
 
         input_ids, labels = _encode_unsupervised_example(
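
With the switch in place, each of the warnings above should appear once per run instead of once per worker under a multi-process launch. A hypothetical smoke test against the sketch near the top of this page:

# Hypothetical check, assuming the Rank0Logger sketch above is importable in the current scope.
import os

os.environ["RANK"] = "1"                              # simulate a non-main worker
logger = get_logger(__name__)
logger.warning_rank0("Dropped invalid example: ...")  # suppressed on rank 1

os.environ["RANK"] = "0"                              # simulate the main process
logger.warning_rank0("Dropped invalid example: ...")  # emitted exactly once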