support rank0 logger

Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
hiyouga
2024-11-02 18:31:04 +08:00
parent ceb701c2d4
commit 093eda2ad6
42 changed files with 316 additions and 252 deletions

View File

@@ -16,7 +16,7 @@ import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from ..extras.logging import get_logger
from ..extras import logging
from .data_utils import Role
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
from .parser import DatasetAttr
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _convert_images(
@@ -167,7 +167,7 @@ def convert_sharegpt(
broken_data = False
for turn_idx, message in enumerate(messages):
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
logger.warning(f"Invalid role tag in {messages}.")
logger.warning_rank0(f"Invalid role tag in {messages}.")
broken_data = True
aligned_messages.append(
@@ -177,7 +177,7 @@ def convert_sharegpt(
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
dataset_attr.ranking and len(aligned_messages) % 2 == 0
):
logger.warning(f"Invalid message count in {messages}.")
logger.warning_rank0(f"Invalid message count in {messages}.")
broken_data = True
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
@@ -198,7 +198,7 @@ def convert_sharegpt(
chosen[dataset_attr.role_tag] not in accept_tags[-1]
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
):
logger.warning(f"Invalid role tag in {[chosen, rejected]}.")
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
broken_data = True
prompt = aligned_messages
@@ -211,7 +211,7 @@ def convert_sharegpt(
response = aligned_messages[-1:]
if broken_data:
logger.warning("Skipping this abnormal example.")
logger.warning_rank0("Skipping this abnormal example.")
prompt, response = [], []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)

View File

@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
from ..extras.logging import get_logger
from ..extras import logging
if TYPE_CHECKING:
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
from ..hparams import DataArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@@ -56,12 +56,12 @@ def merge_dataset(
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
return interleave_datasets(
datasets=all_datasets,

View File

@@ -20,8 +20,8 @@ import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk
from transformers.utils.versions import require_version
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
from .template import Template
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _load_single_dataset(
@@ -51,7 +51,7 @@ def _load_single_dataset(
r"""
Loads a single dataset and aligns it to the standard format.
"""
logger.info(f"Loading dataset {dataset_attr}...")
logger.info_rank0(f"Loading dataset {dataset_attr}...")
data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
data_path = dataset_attr.dataset_name
@@ -141,7 +141,7 @@ def _load_single_dataset(
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
dataset = dataset.select(indexes)
logger.info(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
if data_args.max_samples is not None: # truncate dataset
max_samples = min(data_args.max_samples, len(dataset))
@@ -237,9 +237,9 @@ def get_dataset(
# Load tokenized dataset
if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path):
logger.warning("Loading dataset from disk will ignore other data arguments.")
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
logger.info(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
dataset_module: Dict[str, "Dataset"] = {}
if "train" in dataset_dict:
@@ -290,8 +290,8 @@ def get_dataset(
if data_args.tokenized_path is not None:
if training_args.should_save:
dataset_dict.save_to_disk(data_args.tokenized_path)
logger.info(f"Tokenized dataset saved at {data_args.tokenized_path}.")
logger.info(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
sys.exit(0)

View File

@@ -15,8 +15,8 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import infer_seqlen
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _encode_feedback_example(
@@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
@@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
if desirable_num == 0 or undesirable_num == 0:
logger.warning("Your dataset only has one preference type.")
logger.warning_rank0("Your dataset only has one preference type.")
return model_inputs

View File

@@ -15,8 +15,8 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import infer_seqlen
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _encode_pairwise_example(
@@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(

View File

@@ -15,8 +15,8 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import greedy_knapsack, infer_seqlen
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _encode_supervised_example(
@@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_supervised_example(
@@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
length2indexes = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_supervised_example(
@@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
)
length = len(input_ids)
if length > data_args.cutoff_len:
logger.warning(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
else:
lengths.append(length)
length2indexes[length].append(valid_num)

View File

@@ -15,7 +15,7 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.logging import get_logger
from ...extras import logging
from ..data_utils import Role
from .processor_utils import infer_seqlen
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
from ..template import Template
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def _encode_unsupervised_example(
@@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset(
model_inputs = defaultdict(list)
for i in range(len(examples["_prompt"])):
if len(examples["_prompt"][i]) % 2 != 1:
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
logger.warning_rank0(
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
)
continue
input_ids, labels = _encode_unsupervised_example(

View File

@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version
from typing_extensions import override
from ..extras.logging import get_logger
from ..extras import logging
from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import get_mm_plugin
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
from .mm_plugin import BasePlugin
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
@dataclass
@@ -275,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
if is_added:
logger.info(f"Add eos token: {tokenizer.eos_token}")
logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
else:
logger.info(f"Replace eos token: {tokenizer.eos_token}")
logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
def _jinja_escape(content: str) -> str:
@@ -370,7 +370,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
raise ValueError("Current template does not support `train_on_prompt`.")
if data_args.tool_format is not None:
logger.info(f"Using tool format: {data_args.tool_format}.")
logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
@@ -388,21 +388,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info(f"Add pad token: {tokenizer.pad_token}")
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
if stop_words:
num_added_tokens = tokenizer.add_special_tokens(
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
)
logger.info("Add {} to stop words.".format(",".join(stop_words)))
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
if tokenizer.chat_template is None or template.replace_jinja_template:
try:
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
except ValueError as e:
logger.info(f"Cannot add this chat template to tokenizer: {e}.")
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
return template