support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
@@ -16,7 +16,7 @@ import os
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras import logging
|
||||
from .data_utils import Role
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _convert_images(
|
||||
@@ -167,7 +167,7 @@ def convert_sharegpt(
|
||||
broken_data = False
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||
logger.warning(f"Invalid role tag in {messages}.")
|
||||
logger.warning_rank0(f"Invalid role tag in {messages}.")
|
||||
broken_data = True
|
||||
|
||||
aligned_messages.append(
|
||||
@@ -177,7 +177,7 @@ def convert_sharegpt(
|
||||
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
||||
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
||||
):
|
||||
logger.warning(f"Invalid message count in {messages}.")
|
||||
logger.warning_rank0(f"Invalid message count in {messages}.")
|
||||
broken_data = True
|
||||
|
||||
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
|
||||
@@ -198,7 +198,7 @@ def convert_sharegpt(
|
||||
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
):
|
||||
logger.warning(f"Invalid role tag in {[chosen, rejected]}.")
|
||||
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
|
||||
broken_data = True
|
||||
|
||||
prompt = aligned_messages
|
||||
@@ -211,7 +211,7 @@ def convert_sharegpt(
|
||||
response = aligned_messages[-1:]
|
||||
|
||||
if broken_data:
|
||||
logger.warning("Skipping this abnormal example.")
|
||||
logger.warning_rank0("Skipping this abnormal example.")
|
||||
prompt, response = [], []
|
||||
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
|
||||
Reference in New Issue
Block a user