fix lora target

Former-commit-id: d822e41e7ac7e310ee49e347fc45754284ce30b8
commit f91c5f2638 (parent 7143c551ab)
Author: hiyouga
Date: 2023-09-09 17:04:45 +08:00
7 changed files with 63 additions and 43 deletions


@@ -2,9 +2,9 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
-from transformers import TrainerState, TrainerControl
+from transformers import GenerationConfig, TrainerState, TrainerControl
 from trl import PPOTrainer
 from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
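
Note: the new GenerationConfig import lets the trainer pass decoding parameters as one explicit config object instead of loose keyword arguments to generate. A minimal sketch of the pattern (the parameter values here are illustrative, not taken from this repo):

    from transformers import GenerationConfig

    # Hypothetical decoding parameters; any dict of valid generation
    # arguments can be unpacked into an explicit GenerationConfig.
    generating_args = {"do_sample": True, "temperature": 0.9, "top_p": 0.7}
    generation_config = GenerationConfig(**generating_args)
    # model.generate(generation_config=generation_config, ...) then uses
    # these settings without mutating the model's default config.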
@@ -78,10 +78,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")

         # Keyword arguments for `model.generate`
-        gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs["eos_token_id"] = [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids
-        gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
-        gen_kwargs["logits_processor"] = get_logits_processor()
+        generating_args = self.generating_args.to_dict()
+        generating_args.update(dict(
+            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+            pad_token_id=self.tokenizer.pad_token_id
+        ))

         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
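
Note: trl's LengthSampler, constructed above with (max_target_length // 2, max_target_length), draws a fresh random response length each time it is called, so the per-batch generation budget varies. A small sketch, assuming max_target_length = 256:

    from trl.core import LengthSampler

    length_sampler = LengthSampler(128, 256)  # samples from [min_value, max_value)
    max_new_tokens = length_sampler()         # e.g. 173; resampled every batch
    print(max_new_tokens)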
@@ -103,7 +104,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             self.model.eval()

             # Get inputs
-            queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
+            queries, responses = self.get_inputs(batch, length_sampler, generating_args)
             self.tokenizer.padding_side = "right" # change padding side
             rewards = self.get_rewards(queries, responses, unwrapped_model)
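
Note: queries are left-padded during generation, so every prompt ends at the same index and the response can be sliced off at input_ids.size(-1); padding is then switched back to the right before reward computation. A runnable illustration of that slicing, with toy token ids and pad id assumed to be 0:

    import torch

    pad = 0
    query = torch.tensor([[pad, pad, 5, 6]])          # left-padded prompt
    full = torch.tensor([[pad, pad, 5, 6, 7, 8, 9]])  # prompt + generated tokens
    response = full[:, query.size(-1):]               # slice off the prompt
    print(response)  # tensor([[7, 8, 9]])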
@@ -152,32 +153,36 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     def get_inputs(
         self,
         batch: Dict[str, torch.Tensor],
-        length_sampler: Optional[Callable] = None,
-        **generation_kwargs
+        length_sampler: Callable,
+        generating_args: Dict[str, Any]
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
-        if length_sampler is not None:
-            generation_kwargs["max_new_tokens"] = length_sampler()
+        generating_args["max_new_tokens"] = length_sampler()
+        gen_kwargs = dict(
+            generation_config=GenerationConfig(**generating_args),
+            logits_processor=get_logits_processor(),
+            **batch
+        )

+        input_ids = batch["input_ids"]
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)

         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
         # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
         if unwrapped_model.pretrained_model.generation_config._from_model_config:
             unwrapped_model.pretrained_model.generation_config._from_model_config = False

+        response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
+        query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()
         queries, responses = [], []
-        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
             response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()

             if len(response_index) == 0:
                 response_length = 1 # allow empty response
             elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
                 response_length = response_index[-1] + 2 # save the EOS token
             else:
                 response_length = response_index[-1] + 1

             queries.append(query[i, query_length:]) # remove padding from left
             responses.append(response[i, :response_length]) # remove padding from right
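
Note: when pad_token_id equals eos_token_id (common for LLaMA-style models), the index of the last non-pad token is extended by 2 so the slice keeps the EOS token that follows the response. A runnable check of that branch, with toy token ids and pad = eos = 0 assumed:

    import torch

    pad_token_id = 0  # assumed equal to eos_token_id
    response = torch.tensor([7, 8, 9, 0, 0, 0])  # response, EOS(=pad), padding
    response_index = (response != pad_token_id).nonzero()
    response_length = int(response_index[-1]) + 2  # keep the EOS at index 3
    print(response[:response_length])  # tensor([7, 8, 9, 0])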
@@ -204,7 +209,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         rewards = []
         for i in range(values.size(0)):
-            end_index = batch["attention_mask"][i].nonzero()[-1]
+            end_index = batch["attention_mask"][i].nonzero()[-1] # use the score on the EOS token
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type

         replace_model(unwrapped_model, target="default")
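
Note: the value head emits a score per token, and the score at the last attended position (the EOS token of the right-padded sequence) is taken as the scalar reward, as the added comment states. A minimal sketch with toy values:

    import torch

    values = torch.tensor([[0.1, 0.4, 0.9, 0.0],
                           [0.2, 0.7, 0.0, 0.0]])  # per-token scores
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])
    rewards = []
    for i in range(values.size(0)):
        end_index = attention_mask[i].nonzero()[-1]   # last non-pad position
        rewards.append(values[i, end_index].float())  # fp32 scalar reward
    print(rewards)  # [tensor([0.9000]), tensor([0.7000])]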