use fp16 model, add logcallback
Former-commit-id: bea275d51338b49ce855eec0178e759607265e3d
This commit is contained in:
@@ -4,15 +4,14 @@ import torch
|
||||
from tqdm import tqdm
|
||||
from typing import Callable, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.trainer import TrainerState
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerState
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
|
||||
from trl.core import LengthSampler
|
||||
from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
|
||||
|
||||
from .peft_trainer import PeftTrainer
|
||||
from .peft_trainer import PeftTrainer, LogCallback
|
||||
|
||||
from .config import FinetuningArguments
|
||||
|
||||
@@ -40,15 +39,41 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
|
||||
})
|
||||
|
||||
|
||||
def cast_layernorm_dtype(
|
||||
model: AutoModelForCausalLMWithValueHead,
|
||||
layer_norm_names: List[str] = ["layernorm"], # for chatglm setting
|
||||
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
|
||||
|
||||
layer_norm_state_dict = {}
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
if layer_norm_params is not None:
|
||||
param.data = layer_norm_params[name] # restore float32 weights
|
||||
else:
|
||||
layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
|
||||
param.data = param.data.to(torch.float16)
|
||||
|
||||
return model, layer_norm_state_dict
|
||||
|
||||
|
||||
class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||
r"""
|
||||
Inherits PPOTrainer.
|
||||
"""
|
||||
|
||||
def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: List[LogCallback],
|
||||
**kwargs
|
||||
):
|
||||
PPOTrainer.__init__(self, **kwargs)
|
||||
self.args = training_args
|
||||
self.finetuning_args = finetuning_args
|
||||
self.log_callback = callbacks[0]
|
||||
self.state = TrainerState()
|
||||
self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
|
||||
|
||||
@@ -63,6 +88,11 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||
num_train_epochs = self.args.num_train_epochs
|
||||
max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)
|
||||
|
||||
self.state.max_steps = max_steps
|
||||
self.state.num_train_epochs = num_train_epochs
|
||||
self.state.is_local_process_zero = self.is_local_process_zero()
|
||||
self.state.is_world_process_zero = self.is_world_process_zero()
|
||||
|
||||
if self.is_world_process_zero():
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {num_examples}")
|
||||
@@ -144,6 +174,7 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||
print(logs)
|
||||
logs["step"] = step
|
||||
self.state.log_history.append(logs)
|
||||
self.log_callback.on_log(self.args, self.state, None)
|
||||
loss_meter.reset()
|
||||
reward_meter.reset()
|
||||
|
||||
@@ -154,8 +185,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||
def generate(
|
||||
self,
|
||||
inputs: Dict[str, torch.Tensor],
|
||||
length_sampler: Callable = None,
|
||||
return_prompt: bool = True,
|
||||
length_sampler: Optional[Callable] = None,
|
||||
return_prompt: Optional[bool] = True,
|
||||
**generation_kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
@@ -163,6 +194,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
self.model, layer_norm_params = cast_layernorm_dtype(self.model)
|
||||
|
||||
if length_sampler is not None:
|
||||
generation_kwargs["max_new_tokens"] = length_sampler()
|
||||
|
||||
@@ -175,6 +208,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||
if unwrapped_model.pretrained_model.generation_config._from_model_config:
|
||||
unwrapped_model.pretrained_model.generation_config._from_model_config = False
|
||||
|
||||
self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
|
||||
|
||||
if not return_prompt and not self.is_encoder_decoder:
|
||||
return response[:, inputs["input_ids"].size(1):]
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user