support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
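In a multi-GPU launch every process would otherwise emit the same messages, so this commit routes logging in the parser module through the project's own wrapper (`..extras.logging`), whose `info_rank0` / `warning_rank0` methods log only on the main process. Below is a minimal sketch of such a wrapper, assuming rank detection via the `LOCAL_RANK` environment variable set by `torchrun`; the real implementation in `extras/logging.py` may differ.

```python
# Minimal sketch of a rank-0-aware logger (assumption: rank is read from LOCAL_RANK,
# which torchrun/accelerate set; the project's actual implementation may differ).
import logging
import os


class _Rank0Logger(logging.Logger):
    """Logger whose *_rank0 methods emit only on the main process."""

    def info_rank0(self, *args, **kwargs) -> None:
        if int(os.getenv("LOCAL_RANK", "0")) == 0:  # single-process runs default to rank 0
            self.info(*args, **kwargs)

    def warning_rank0(self, *args, **kwargs) -> None:
        if int(os.getenv("LOCAL_RANK", "0")) == 0:
            self.warning(*args, **kwargs)


def get_logger(name: str) -> "_Rank0Logger":
    # setLoggerClass only affects loggers created after this call, which is enough for a sketch.
    logging.setLoggerClass(_Rank0Logger)
    try:
        return logging.getLogger(name)  # type: ignore[return-value]
    finally:
        logging.setLoggerClass(logging.Logger)
```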
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 import os
 import sys
 from typing import Any, Dict, Optional, Tuple
@@ -29,8 +28,8 @@ from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
 from transformers.utils.versions import require_version
 
+from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.logging import get_logger
 from ..extras.misc import check_dependencies, get_current_device
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
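Note that the added `from ..extras import logging` rebinds the name `logging` to the project's wrapper module inside this file, which is why the stdlib `import logging` was dropped in the previous hunk and why `logging.INFO` disappears from `_set_transformers_logging` below.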
@@ -39,7 +38,7 @@ from .generating_args import GeneratingArguments
 from .model_args import ModelArguments
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 check_dependencies()
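With the module-level logger now created through the wrapper, call sites gain the `_rank0` variants. A hypothetical two-process usage, assuming the package root is importable as `llamafactory`:

```python
# demo.py -- launch with: torchrun --nproc_per_node=2 demo.py
from llamafactory.extras import logging  # assumed import path

logger = logging.get_logger(__name__)

logger.warning("printed by every rank")          # would appear twice
logger.warning_rank0("printed by rank 0 only")   # appears once
```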
@@ -73,8 +72,8 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
     return (*parsed_args,)
 
 
-def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
-    transformers.utils.logging.set_verbosity(log_level)
+def _set_transformers_logging() -> None:
+    transformers.utils.logging.set_verbosity_info()
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()
 
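`set_verbosity_info()` is an existing convenience in `transformers.utils.logging` that pins the verbosity to INFO, so the `log_level` parameter, and with it the last use of the stdlib `logging` module, is no longer needed. Roughly:

```python
import transformers

# The two calls below should be equivalent: set_verbosity_info() fixes the level to INFO.
transformers.utils.logging.set_verbosity(transformers.utils.logging.INFO)
transformers.utils.logging.set_verbosity_info()
```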
@@ -104,7 +103,7 @@ def _verify_model_args(
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
 
     if data_args.template == "yi" and model_args.use_fast_tokenizer:
-        logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
+        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
         model_args.use_fast_tokenizer = False
 
 
@@ -261,7 +260,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
 
     if data_args.neat_packing and not data_args.packing:
-        logger.warning("`neat_packing` requires `packing` is True. Change `packing` to True.")
+        logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
         data_args.packing = True
 
     _verify_model_args(model_args, data_args, finetuning_args)
@@ -274,22 +273,26 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and model_args.resize_vocab
         and finetuning_args.additional_target is None
     ):
-        logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
+        logger.warning_rank0(
+            "Remember to add embedding layers to `additional_target` to make the added tokens trainable."
+        )
 
     if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
-        logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
+        logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")
 
     if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
-        logger.warning("We recommend enable mixed precision training.")
+        logger.warning_rank0("We recommend enable mixed precision training.")
 
     if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
-        logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.")
+        logger.warning_rank0(
+            "Using GaLore with mixed precision training may significantly increases GPU memory usage."
+        )
 
     if (not training_args.do_train) and model_args.quantization_bit is not None:
-        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
+        logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
 
     if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
-        logger.warning("Specify `ref_model` for computing rewards at evaluation.")
+        logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")
 
     # Post-process training arguments
     if (
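The three-line `warning_rank0(...)` call sites in this hunk appear to be only a line-length reflow: the longer method name pushes the two longest messages past the formatter's line limit, so the message string moves onto its own line while the text itself is unchanged.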
@@ -297,13 +300,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
-        logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
         training_args.ddp_find_unused_parameters = False
 
     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
         can_resume_from_checkpoint = False
         if training_args.resume_from_checkpoint is not None:
-            logger.warning("Cannot resume from checkpoint in current stage.")
+            logger.warning_rank0("Cannot resume from checkpoint in current stage.")
             training_args.resume_from_checkpoint = None
     else:
         can_resume_from_checkpoint = True
@@ -323,15 +326,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 
         if last_checkpoint is not None:
             training_args.resume_from_checkpoint = last_checkpoint
-            logger.info(f"Resuming training from {training_args.resume_from_checkpoint}.")
-            logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
+            logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
+            logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")
 
     if (
         finetuning_args.stage in ["rm", "ppo"]
         and finetuning_args.finetuning_type == "lora"
         and training_args.resume_from_checkpoint is not None
     ):
-        logger.warning(
+        logger.warning_rank0(
             "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
                 training_args.resume_from_checkpoint
             )