support rank0 logger
Former-commit-id: 84528eabe560091bfd866b6a0ca864085af7529b
This commit is contained in:
@@ -37,7 +37,7 @@ from trl.core import PPODecorators, logprobs_from_logits
 from trl.models.utils import unwrap_model_for_generation
 from typing_extensions import override

-from ...extras.logging import get_logger
+from ...extras import logging
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
 from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -58,7 +58,7 @@ if TYPE_CHECKING:
     from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class CustomPPOTrainer(PPOTrainer, Trainer):
@@ -112,7 +112,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             ]
             ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
             if ppo_config.log_with is not None:
-                logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
+                logger.warning_rank0("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
                 ppo_config.log_with = None

         # Create optimizer and scheduler
@@ -160,7 +160,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
         )
         if self.args.max_steps > 0:
-            logger.info("max_steps is given, it will override any value given in num_train_epochs")
+            logger.info_rank0("max_steps is given, it will override any value given in num_train_epochs")

         self.amp_context = torch.autocast(self.current_device.type)
         warnings.simplefilter("ignore") # remove gc warnings on ref model
@@ -216,20 +216,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.state.is_local_process_zero = self.is_local_process_zero()
         self.state.is_world_process_zero = self.is_world_process_zero()

-        if self.is_world_process_zero():
-            logger.info("***** Running training *****")
-            logger.info(f" Num examples = {num_examples:,}")
-            logger.info(f" Num Epochs = {num_train_epochs:,}")
-            logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
-            logger.info(
-                " Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
-                    total_train_batch_size
-                )
+        logger.info_rank0("***** Running training *****")
+        logger.info_rank0(f" Num examples = {num_examples:,}")
+        logger.info_rank0(f" Num Epochs = {num_train_epochs:,}")
+        logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
+        logger.info_rank0(
+            " Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
+                total_train_batch_size
             )
-            logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
-            logger.info(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
-            logger.info(f" Total training steps = {max_steps:,}")
-            logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]:,}")
+        )
+        logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
+        logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
+        logger.info_rank0(f" Total training steps = {max_steps:,}")
+        logger.info_rank0(f" Number of trainable parameters = {count_parameters(self.model)[0]:,}")

         dataiter = iter(self.dataloader)
         loss_meter = AverageMeter()
@@ -269,7 +268,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                     batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                     self.log_stats(stats, batch, rewards)
                 except Exception:
-                    logger.warning("Failed to save stats due to unknown errors.")
+                    logger.warning_rank0("Failed to save stats due to unknown errors.")

             self.state.global_step += 1
             self.callback_handler.on_step_end(self.args, self.state, self.control)
@@ -498,7 +497,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 if self.args.should_save:
                     self._save(output_dir, state_dict=state_dict)
             except ValueError:
-                logger.warning(
+                logger.warning_rank0(
                     " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
                     " use zero_to_fp32.py to recover weights"
                 )
|
||||
Reference in New Issue
Block a user