support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
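
Note: the change below replaces the stdlib logger with the project's own
..extras.logging module, whose loggers expose rank-aware helpers such as
info_rank0. As a rough sketch of the idea only (not the project's actual
code: the _Rank0Logger name and the RANK environment-variable fallback are
assumptions), such a helper lets only the main process of a distributed run
emit the record:

import logging
import os


class _Rank0Logger(logging.Logger):
    """Illustrative logger whose info_rank0 only logs on the main process."""

    def info_rank0(self, *args, **kwargs) -> None:
        # torchrun exports RANK; single-process runs default to rank 0.
        if int(os.getenv("RANK", "0")) == 0:
            self.info(*args, **kwargs)


def get_logger(name: str) -> "_Rank0Logger":
    logging.setLoggerClass(_Rank0Logger)
    return logging.getLogger(name)  # type: ignore[return-value]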
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import json
-import logging
 import os
 import signal
 import sys
@@ -34,8 +33,8 @@ from transformers.utils import (
 )
 from typing_extensions import override
 
+from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.logging import LoggerHandler, get_logger
 from ..extras.misc import get_peak_memory
 
 
@@ -48,7 +47,7 @@ if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 def fix_valuehead_checkpoint(
@@ -92,7 +91,7 @@ def fix_valuehead_checkpoint(
     else:
         torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
 
-    logger.info(f"Value head model saved at: {output_dir}")
+    logger.info_rank0(f"Value head model saved at: {output_dir}")
 
 
 class FixValueHeadModelCallback(TrainerCallback):
@@ -145,7 +144,7 @@ class PissaConvertCallback(TrainerCallback):
         if args.should_save:
             model = kwargs.pop("model")
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
-            logger.info(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
+            logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
             if isinstance(model, PeftModel):
                 init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                 setattr(model.peft_config["default"], "init_lora_weights", True)
@@ -159,7 +158,7 @@
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
             pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
             pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
-            logger.info(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
+            logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
             # 1. save a pissa backup with init_lora_weights: True
             # 2. save a converted lora with init_lora_weights: pissa
             # 3. load the pissa backup with init_lora_weights: True
@@ -200,8 +199,8 @@ class LogCallback(TrainerCallback):
         self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
         if self.webui_mode:
             signal.signal(signal.SIGABRT, self._set_abort)
-            self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
-            logging.root.addHandler(self.logger_handler)
+            self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
+            logging.add_handler(self.logger_handler)
             transformers.logging.add_handler(self.logger_handler)
 
     def _set_abort(self, signum, frame) -> None:
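
Note: LoggerHandler now comes from the project's logging module, and the
handler is registered through logging.add_handler instead of touching
logging.root directly. A hedged guess at what add_handler might do (the
"llamafactory" root name is an assumption), mirroring
transformers.utils.logging.add_handler by attaching the handler to the
package root logger so every get_logger child inherits it:

import logging as _std_logging

_ROOT_NAME = "llamafactory"  # assumed package root name; illustrative only


def add_handler(handler: "_std_logging.Handler") -> None:
    # Child loggers (e.g. llamafactory.train.callbacks) propagate records
    # up to this root logger, so one handler covers the whole package.
    _std_logging.getLogger(_ROOT_NAME).addHandler(handler)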
@@ -243,7 +242,7 @@
             and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
             and args.overwrite_output_dir
         ):
-            logger.warning("Previous trainer log in this folder will be deleted.")
+            logger.warning_once("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))
 
     @override
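
Note: warning_once suppresses duplicate warnings across repeated calls.
One common way to implement it, and a minimal sketch only (whether this
project uses the same trick as transformers' own warning_once is an
assumption), is to memoize on the message with functools.lru_cache so the
body runs once per distinct text:

import functools
import logging

_logger = logging.getLogger(__name__)


@functools.lru_cache(None)
def warning_once(msg: str) -> None:
    # lru_cache caches by argument, so a repeated msg never re-enters here.
    _logger.warning(msg)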
@@ -310,7 +309,7 @@
 
         logs = {k: v for k, v in logs.items() if v is not None}
         if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
-            logger.info(
+            logger.info_rank0(
                 "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
                     logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
                 )