support rank0 logger

Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
Author: hiyouga
Date: 2024-11-02 18:31:04 +08:00
Parent: ceb701c2d4
Commit: 093eda2ad6
42 changed files with 316 additions and 252 deletions
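For context, the diff below switches this module from the bare get_logger helper to a shared logging module that exposes rank-aware calls (info_rank0, warning_rank0), which emit only on the main process of a distributed run. The repository's actual implementation is not part of this diff; the following is a minimal sketch of how such a logger could be built, assuming a torchrun-style RANK environment variable. Only the names get_logger, info_rank0, and warning_rank0 are taken from the changed code; everything else is illustrative.

    # Hypothetical sketch of a rank0-aware logger; not the repository's actual code.
    import logging
    import os


    class _Rank0Logger(logging.Logger):
        """Logger whose *_rank0 methods emit only on the main (rank 0) process."""

        @staticmethod
        def _is_rank0() -> bool:
            # torchrun / accelerate set RANK; default to 0 for single-process runs.
            return int(os.environ.get("RANK", "0")) == 0

        def info_rank0(self, msg, *args, **kwargs) -> None:
            if self._is_rank0():
                self.info(msg, *args, **kwargs)

        def warning_rank0(self, msg, *args, **kwargs) -> None:
            if self._is_rank0():
                self.warning(msg, *args, **kwargs)


    def get_logger(name: str) -> "_Rank0Logger":
        # Temporarily swap the logger class so newly created loggers get the rank0 helpers.
        original_class = logging.getLoggerClass()
        logging.setLoggerClass(_Rank0Logger)
        try:
            return logging.getLogger(name)
        finally:
            logging.setLoggerClass(original_class)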


@@ -20,8 +20,8 @@ import torch
 from transformers import PreTrainedModel
 
 from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.logging import get_logger
 from ..hparams import get_infer_args, get_train_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
     from transformers import TrainerCallback
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
@@ -91,7 +91,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
         setattr(model.config, "torch_dtype", output_dtype)
         model = model.to(output_dtype)
-        logger.info(f"Convert model dtype to: {output_dtype}.")
+        logger.info_rank0(f"Convert model dtype to: {output_dtype}.")
 
     model.save_pretrained(
         save_directory=model_args.export_dir,
@@ -117,13 +117,13 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
                 os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
                 os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
             )
-            logger.info(f"Copied valuehead to {model_args.export_dir}.")
+            logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
         elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
             shutil.copy(
                 os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
                 os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
             )
-            logger.info(f"Copied valuehead to {model_args.export_dir}.")
+            logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
 
     try:
         tokenizer.padding_side = "left"  # restore padding side
@@ -138,4 +138,4 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
                 processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
 
     except Exception as e:
-        logger.warning(f"Cannot save tokenizer, please copy the files manually: {e}.")
+        logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")