refactor model_dtype, fix PPO trainer

Former-commit-id: 3e17ee5afbcb823a7c9a2f91864b3750cd79edb4
This commit is contained in:
hiyouga
2023-10-11 23:16:01 +08:00
parent a2d08ce961
commit 3198a7e5f4
10 changed files with 104 additions and 119 deletions

View File

@@ -10,14 +10,15 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from trl import PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import GeneratingArguments
from llmtuner.hparams import ModelArguments, GeneratingArguments
logger = get_logger(__name__)
@@ -30,10 +31,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
def __init__(
self,
model_args: "ModelArguments",
training_args: "Seq2SeqTrainingArguments",
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
compute_dtype: torch.dtype,
**kwargs
):
PPOTrainer.__init__(self, **kwargs)
@@ -41,11 +42,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
self.args = training_args
self.generating_args = generating_args
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
self.compute_dtype = compute_dtype
self.model_args = model_args
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
**generating_args.to_dict()
)
self.state = TrainerState()
self.control = TrainerControl()
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
def ppo_train(self) -> None:
r"""
@@ -74,13 +80,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
# Keyword arguments for `model.generate`
generating_args = self.generating_args.to_dict()
generating_args.update(dict(
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
pad_token_id=self.tokenizer.pad_token_id
))
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
steps_trained = 0
@@ -98,7 +97,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.eval()
# Get inputs
queries, responses = self.get_inputs(batch, generating_args)
queries, responses = self.get_inputs(batch)
self.tokenizer.padding_side = "right" # change padding side
rewards = self.get_rewards(queries, responses, unwrapped_model)
@@ -159,27 +158,24 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
)
@torch.no_grad()
def get_inputs(
self,
batch: Dict[str, torch.Tensor],
generating_args: Dict[str, Any]
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
Generates model's responses given queries.
"""
gen_kwargs = dict(
generation_config=GenerationConfig(**generating_args),
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(self.model)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
response: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config,
logits_processor=get_logits_processor(),
**batch
)
input_ids = batch["input_ids"]
self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()
if self.model_args.upcast_layernorm:
restore_layernorm(self.model, layernorm_params)
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
queries, responses = [], []
for i in range(len(query)):
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]