support rank0 logger

Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
Author: hiyouga
Date: 2024-11-02 18:31:04 +08:00
Parent: ceb701c2d4
Commit: 093eda2ad6
42 changed files with 316 additions and 252 deletions
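
The hunks below mechanically swap `logger.info(...)` / `logger.warning(...)` for `info_rank0` / `warning_rank0` and route logger creation through the `..extras.logging` module, so that under distributed training each message is emitted once (by the main process) instead of once per rank. The logging module itself is not part of this excerpt, so the following is only a minimal sketch of how such rank-aware helpers could look; the `_rank()` helper, the `RANK`/`LOCAL_RANK` environment lookup, and the handler setup are illustrative assumptions rather than the project's actual code.

# Sketch only -- the real extras logging module is not shown in this diff;
# this is one way such rank-aware helpers could be written.
import logging
import os
import sys


def _rank() -> int:
    # torchrun / accelerate export RANK (and LOCAL_RANK); default to 0 so that
    # single-process runs still log normally.
    return int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0")))


class _Rank0Logger(logging.Logger):
    """Logger with *_rank0 variants that only emit on the main process."""

    def info_rank0(self, msg, *args, **kwargs) -> None:
        if _rank() == 0:
            self.info(msg, *args, **kwargs)

    def warning_rank0(self, msg, *args, **kwargs) -> None:
        if _rank() == 0:
            self.warning(msg, *args, **kwargs)


def get_logger(name: str) -> "_Rank0Logger":
    # setLoggerClass only affects loggers created afterwards, which covers the
    # module-level `logger = logging.get_logger(__name__)` pattern in the diff.
    logging.setLoggerClass(_Rank0Logger)
    logger = logging.getLogger(name)
    logging.setLoggerClass(logging.Logger)  # restore the default for other callers
    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(logging.Formatter("[%(levelname)s|%(name)s] %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger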


@@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

-from ..extras.logging import get_logger
+from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
     from ..hparams import FinetuningArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _setup_full_tuning(
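
With the import switched to the module itself, call sites keep the same one-line setup and simply use the `_rank0` variants; assuming an implementation along the lines of the sketch above, each message below is printed once per job rather than once per process when launched with `torchrun`:

logger = logging.get_logger(__name__)  # same pattern as in the hunk above

logger.info_rank0("Fine-tuning method: Full")  # emitted on rank 0 only
logger.warning_rank0("Vocab has been resized.")  # likewise deduplicated across ranks
logger.info("Always printed")  # plain .info still logs on every rank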
@@ -45,7 +45,7 @@ def _setup_full_tuning(
     if not is_trainable:
         return

-    logger.info("Fine-tuning method: Full")
+    logger.info_rank0("Fine-tuning method: Full")
     forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if not any(forbidden_module in name for forbidden_module in forbidden_modules):
@@ -64,7 +64,7 @@ def _setup_freeze_tuning(
     if not is_trainable:
         return

-    logger.info("Fine-tuning method: Freeze")
+    logger.info_rank0("Fine-tuning method: Freeze")
     if hasattr(model.config, "text_config"):  # composite models
         config = getattr(model.config, "text_config")
     else:
@@ -133,7 +133,7 @@ def _setup_freeze_tuning(
         else:
             param.requires_grad_(False)

-    logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
+    logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))


 def _setup_lora_tuning(
@@ -145,7 +145,7 @@ def _setup_lora_tuning(
     cast_trainable_params_to_fp32: bool,
 ) -> "PeftModel":
     if is_trainable:
-        logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+        logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))

     adapter_to_resume = None
@@ -182,7 +182,7 @@ def _setup_lora_tuning(
             model = model.merge_and_unload()

         if len(adapter_to_merge) > 0:
-            logger.info(f"Merged {len(adapter_to_merge)} adapter(s).")
+            logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")

         if adapter_to_resume is not None:  # resume lora training
             if model_args.use_unsloth:
@@ -190,7 +190,7 @@ def _setup_lora_tuning(
             else:
                 model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)

-        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
+        logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

     if is_trainable and adapter_to_resume is None:  # create new lora weights while training
         if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -219,7 +219,7 @@ def _setup_lora_tuning(
                     module_names.add(name.split(".")[-1])

             finetuning_args.additional_target = module_names
-            logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
+            logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))

         peft_kwargs = {
             "r": finetuning_args.lora_rank,
@@ -236,10 +236,10 @@ def _setup_lora_tuning(
         else:
             if finetuning_args.pissa_init:
                 if finetuning_args.pissa_iter == -1:
-                    logger.info("Using PiSSA initialization.")
+                    logger.info_rank0("Using PiSSA initialization.")
                     peft_kwargs["init_lora_weights"] = "pissa"
                 else:
-                    logger.info(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
+                    logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
                     peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"

         lora_config = LoraConfig(
@@ -284,11 +284,11 @@ def init_adapter(
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
-        logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
+        logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
     elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
     else:
-        logger.info("Upcasting trainable params to float32.")
+        logger.info_rank0("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True

     if finetuning_args.finetuning_type == "full":