support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
@@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.modeling_utils import is_fsdp_enabled
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras import logging
|
||||
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
|
||||
from .model_utils.quantization import QuantizationMethod
|
||||
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
|
||||
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _setup_full_tuning(
|
||||
@@ -45,7 +45,7 @@ def _setup_full_tuning(
|
||||
if not is_trainable:
|
||||
return
|
||||
|
||||
logger.info("Fine-tuning method: Full")
|
||||
logger.info_rank0("Fine-tuning method: Full")
|
||||
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
|
||||
for name, param in model.named_parameters():
|
||||
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
|
||||
@@ -64,7 +64,7 @@ def _setup_freeze_tuning(
|
||||
if not is_trainable:
|
||||
return
|
||||
|
||||
logger.info("Fine-tuning method: Freeze")
|
||||
logger.info_rank0("Fine-tuning method: Freeze")
|
||||
if hasattr(model.config, "text_config"): # composite models
|
||||
config = getattr(model.config, "text_config")
|
||||
else:
|
||||
@@ -133,7 +133,7 @@ def _setup_freeze_tuning(
|
||||
else:
|
||||
param.requires_grad_(False)
|
||||
|
||||
logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
|
||||
logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))
|
||||
|
||||
|
||||
def _setup_lora_tuning(
|
||||
@@ -145,7 +145,7 @@ def _setup_lora_tuning(
|
||||
cast_trainable_params_to_fp32: bool,
|
||||
) -> "PeftModel":
|
||||
if is_trainable:
|
||||
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
||||
logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
||||
|
||||
adapter_to_resume = None
|
||||
|
||||
@@ -182,7 +182,7 @@ def _setup_lora_tuning(
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if len(adapter_to_merge) > 0:
|
||||
logger.info(f"Merged {len(adapter_to_merge)} adapter(s).")
|
||||
logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")
|
||||
|
||||
if adapter_to_resume is not None: # resume lora training
|
||||
if model_args.use_unsloth:
|
||||
@@ -190,7 +190,7 @@ def _setup_lora_tuning(
|
||||
else:
|
||||
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
|
||||
|
||||
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||
|
||||
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||
@@ -219,7 +219,7 @@ def _setup_lora_tuning(
|
||||
module_names.add(name.split(".")[-1])
|
||||
|
||||
finetuning_args.additional_target = module_names
|
||||
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
|
||||
logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
|
||||
|
||||
peft_kwargs = {
|
||||
"r": finetuning_args.lora_rank,
|
||||
@@ -236,10 +236,10 @@ def _setup_lora_tuning(
|
||||
else:
|
||||
if finetuning_args.pissa_init:
|
||||
if finetuning_args.pissa_iter == -1:
|
||||
logger.info("Using PiSSA initialization.")
|
||||
logger.info_rank0("Using PiSSA initialization.")
|
||||
peft_kwargs["init_lora_weights"] = "pissa"
|
||||
else:
|
||||
logger.info(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
|
||||
logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
|
||||
peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"
|
||||
|
||||
lora_config = LoraConfig(
|
||||
@@ -284,11 +284,11 @@ def init_adapter(
|
||||
if not is_trainable:
|
||||
pass
|
||||
elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
|
||||
logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
|
||||
logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
|
||||
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
|
||||
logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
|
||||
logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
|
||||
else:
|
||||
logger.info("Upcasting trainable params to float32.")
|
||||
logger.info_rank0("Upcasting trainable params to float32.")
|
||||
cast_trainable_params_to_fp32 = True
|
||||
|
||||
if finetuning_args.finetuning_type == "full":
|
||||
|
||||
Reference in New Issue
Block a user