support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
@@ -16,7 +16,7 @@ import os
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras import logging
|
||||
from .data_utils import Role
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _convert_images(
|
||||
@@ -167,7 +167,7 @@ def convert_sharegpt(
|
||||
broken_data = False
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
|
||||
logger.warning(f"Invalid role tag in {messages}.")
|
||||
logger.warning_rank0(f"Invalid role tag in {messages}.")
|
||||
broken_data = True
|
||||
|
||||
aligned_messages.append(
|
||||
@@ -177,7 +177,7 @@ def convert_sharegpt(
|
||||
if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
|
||||
dataset_attr.ranking and len(aligned_messages) % 2 == 0
|
||||
):
|
||||
logger.warning(f"Invalid message count in {messages}.")
|
||||
logger.warning_rank0(f"Invalid message count in {messages}.")
|
||||
broken_data = True
|
||||
|
||||
if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example
|
||||
@@ -198,7 +198,7 @@ def convert_sharegpt(
|
||||
chosen[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
or rejected[dataset_attr.role_tag] not in accept_tags[-1]
|
||||
):
|
||||
logger.warning(f"Invalid role tag in {[chosen, rejected]}.")
|
||||
logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
|
||||
broken_data = True
|
||||
|
||||
prompt = aligned_messages
|
||||
@@ -211,7 +211,7 @@ def convert_sharegpt(
|
||||
response = aligned_messages[-1:]
|
||||
|
||||
if broken_data:
|
||||
logger.warning("Skipping this abnormal example.")
|
||||
logger.warning_rank0("Skipping this abnormal example.")
|
||||
prompt, response = [], []
|
||||
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict
|
||||
|
||||
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
||||
from ..hparams import DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||
@@ -56,12 +56,12 @@ def merge_dataset(
|
||||
return all_datasets[0]
|
||||
elif data_args.mix_strategy == "concat":
|
||||
if data_args.streaming:
|
||||
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
||||
logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
|
||||
|
||||
return concatenate_datasets(all_datasets)
|
||||
elif data_args.mix_strategy.startswith("interleave"):
|
||||
if not data_args.streaming:
|
||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
|
||||
return interleave_datasets(
|
||||
datasets=all_datasets,
|
||||
|
||||
@@ -20,8 +20,8 @@ import numpy as np
|
||||
from datasets import DatasetDict, load_dataset, load_from_disk
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import FILEEXT2TYPE
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import has_tokenized_data
|
||||
from .aligner import align_dataset
|
||||
from .data_utils import merge_dataset, split_dataset
|
||||
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
|
||||
from .template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _load_single_dataset(
|
||||
@@ -51,7 +51,7 @@ def _load_single_dataset(
|
||||
r"""
|
||||
Loads a single dataset and aligns it to the standard format.
|
||||
"""
|
||||
logger.info(f"Loading dataset {dataset_attr}...")
|
||||
logger.info_rank0(f"Loading dataset {dataset_attr}...")
|
||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||
if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
|
||||
data_path = dataset_attr.dataset_name
|
||||
@@ -141,7 +141,7 @@ def _load_single_dataset(
|
||||
|
||||
assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
|
||||
dataset = dataset.select(indexes)
|
||||
logger.info(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
|
||||
logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
|
||||
|
||||
if data_args.max_samples is not None: # truncate dataset
|
||||
max_samples = min(data_args.max_samples, len(dataset))
|
||||
@@ -237,9 +237,9 @@ def get_dataset(
|
||||
# Load tokenized dataset
|
||||
if data_args.tokenized_path is not None:
|
||||
if has_tokenized_data(data_args.tokenized_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
|
||||
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
||||
logger.info(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
|
||||
logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
|
||||
|
||||
dataset_module: Dict[str, "Dataset"] = {}
|
||||
if "train" in dataset_dict:
|
||||
@@ -290,8 +290,8 @@ def get_dataset(
|
||||
if data_args.tokenized_path is not None:
|
||||
if training_args.should_save:
|
||||
dataset_dict.save_to_disk(data_args.tokenized_path)
|
||||
logger.info(f"Tokenized dataset saved at {data_args.tokenized_path}.")
|
||||
logger.info(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
|
||||
logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
|
||||
logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import infer_seqlen
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_feedback_example(
|
||||
@@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
||||
logger.warning_rank0(
|
||||
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||
)
|
||||
continue
|
||||
|
||||
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
|
||||
@@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
|
||||
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
|
||||
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
|
||||
if desirable_num == 0 or undesirable_num == 0:
|
||||
logger.warning("Your dataset only has one preference type.")
|
||||
logger.warning_rank0("Your dataset only has one preference type.")
|
||||
|
||||
return model_inputs
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import infer_seqlen
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_pairwise_example(
|
||||
@@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
||||
logger.warning_rank0(
|
||||
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||
)
|
||||
continue
|
||||
|
||||
chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import greedy_knapsack, infer_seqlen
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_supervised_example(
|
||||
@@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
||||
logger.warning_rank0(
|
||||
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||
)
|
||||
continue
|
||||
|
||||
input_ids, labels = _encode_supervised_example(
|
||||
@@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
|
||||
length2indexes = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
||||
logger.warning_rank0(
|
||||
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||
)
|
||||
continue
|
||||
|
||||
input_ids, labels = _encode_supervised_example(
|
||||
@@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
|
||||
)
|
||||
length = len(input_ids)
|
||||
if length > data_args.cutoff_len:
|
||||
logger.warning(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
|
||||
logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
|
||||
else:
|
||||
lengths.append(length)
|
||||
length2indexes[length].append(valid_num)
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras import logging
|
||||
from ..data_utils import Role
|
||||
from .processor_utils import infer_seqlen
|
||||
|
||||
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
from ..template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _encode_unsupervised_example(
|
||||
@@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset(
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["_prompt"])):
|
||||
if len(examples["_prompt"][i]) % 2 != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
|
||||
logger.warning_rank0(
|
||||
"Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
|
||||
)
|
||||
continue
|
||||
|
||||
input_ids, labels = _encode_unsupervised_example(
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from transformers.utils.versions import require_version
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras import logging
|
||||
from .data_utils import Role
|
||||
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
from .mm_plugin import get_mm_plugin
|
||||
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
|
||||
from .mm_plugin import BasePlugin
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -275,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
|
||||
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
|
||||
|
||||
if is_added:
|
||||
logger.info(f"Add eos token: {tokenizer.eos_token}")
|
||||
logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
|
||||
else:
|
||||
logger.info(f"Replace eos token: {tokenizer.eos_token}")
|
||||
logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")
|
||||
|
||||
if num_added_tokens > 0:
|
||||
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
||||
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
|
||||
|
||||
|
||||
def _jinja_escape(content: str) -> str:
|
||||
@@ -370,7 +370,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||
|
||||
if data_args.tool_format is not None:
|
||||
logger.info(f"Using tool format: {data_args.tool_format}.")
|
||||
logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
|
||||
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
|
||||
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
|
||||
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
|
||||
@@ -388,21 +388,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info(f"Add pad token: {tokenizer.pad_token}")
|
||||
logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")
|
||||
|
||||
if stop_words:
|
||||
num_added_tokens = tokenizer.add_special_tokens(
|
||||
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
|
||||
)
|
||||
logger.info("Add {} to stop words.".format(",".join(stop_words)))
|
||||
logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
|
||||
if num_added_tokens > 0:
|
||||
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
||||
logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")
|
||||
|
||||
if tokenizer.chat_template is None or template.replace_jinja_template:
|
||||
try:
|
||||
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
|
||||
except ValueError as e:
|
||||
logger.info(f"Cannot add this chat template to tokenizer: {e}.")
|
||||
logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")
|
||||
|
||||
return template
|
||||
|
||||
|
||||
Reference in New Issue
Block a user