change to right-padding, update reward score #803

Former-commit-id: baa90415bc8f5ebd423d001378b51c3a3a6c2ec7
hiyouga
2023-09-08 20:04:31 +08:00
parent bb1b67c076
commit 612d97db6f
15 changed files with 97 additions and 59 deletions


@@ -102,6 +102,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Get inputs
             queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
+            self.tokenizer.padding_side = "right" # change padding side
             rewards = self.get_rewards(queries, responses, unwrapped_model)

             # Cast to training mode
@@ -110,6 +111,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Run PPO step
             stats = self.step(queries, responses, rewards)
+            self.tokenizer.padding_side = "left" # restore padding side
             loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
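A note on why the padding side is flipped around get_rewards: generation wants left padding (the last column of the batch is always a real token), while the reward code changed below reads per-row end positions under right padding. A minimal standalone sketch, with a hypothetical pad_batch helper and toy token ids (none of this is the trainer's code):

import torch

# Hypothetical helper (not from the trainer): pad token id lists to a fixed
# length on the chosen side.
def pad_batch(seqs, side, pad_id=0, length=4):
    rows = []
    for s in seqs:
        pad = [pad_id] * (length - len(s))
        rows.append(pad + s if side == "left" else s + pad)
    return torch.tensor(rows)

seqs = [[11, 12, 13], [21, 22]]
print(pad_batch(seqs, "left"))   # [[ 0, 11, 12, 13], [ 0,  0, 21, 22]]
print(pad_batch(seqs, "right"))  # [[11, 12, 13,  0], [21, 22,  0,  0]]
# With right padding the last column may be a pad, which is why get_rewards
# below switches from values[:, -1] to the last attention-masked index.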
@@ -169,7 +171,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
-            response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
+            if len(response_index) == 0:
+                response_length = 1 # allow empty response
+            else:
+                response_length = response_index[-1] + 1
             queries.append(query[i, query_length:]) # remove padding from left
             responses.append(response[i, :response_length]) # remove padding from right
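The empty-response guard above can be exercised in isolation. A self-contained sketch of the same trimming logic, assuming pad_token_id is 0 and using toy tensors (left-padded queries, right-padded responses, second response entirely padding):

import torch

pad_token_id = 0
query = torch.tensor([[0, 0, 5, 6], [0, 7, 8, 9]])     # left-padded queries
response = torch.tensor([[3, 4, 0, 0], [0, 0, 0, 0]])  # right-padded; row 1 is empty

queries, responses = [], []
for i in range(len(query)):
    query_start = (query[i] != pad_token_id).nonzero()[0]  # first real token
    response_index = (response[i] != pad_token_id).nonzero()
    if len(response_index) == 0:
        response_length = 1  # allow empty response, keep a length-1 tensor
    else:
        response_length = response_index[-1] + 1           # last real token + 1
    queries.append(query[i, query_start:])                 # strip left padding
    responses.append(response[i, :response_length])        # strip right padding

print([q.tolist() for q in queries])    # [[5, 6], [7, 8, 9]]
print([r.tolist() for r in responses])  # [[3, 4], [0]]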
@@ -194,7 +200,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
             values = torch.transpose(values, 0, 1)
-        rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
+        rewards = []
+        for i in range(values.size(0)):
+            end_index = batch["attention_mask"][i].nonzero()[-1]
+            rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
         replace_model(unwrapped_model, target="default")
         return rewards
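The point of indexing by attention_mask: with right padding, values[:, -1] lands on a pad position for any row shorter than the batch maximum. A toy sketch of the new selection, with illustrative shapes and numbers only:

import torch

values = torch.tensor([[0.1, 0.4, 0.9, 0.0],
                       [0.2, 0.7, 0.0, 0.0]])  # (batch, seq_len) value-head outputs
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

rewards = []
for i in range(values.size(0)):
    end_index = attention_mask[i].nonzero()[-1]  # last non-pad position
    rewards.append(values[i, end_index].float()) # reads 0.9 and 0.7, not the pads

print(torch.stack(rewards).squeeze().tolist())   # ~[0.9, 0.7]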
@@ -241,7 +251,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             for j in range(len(query_batch)):
                 start = len(query_batch[j]) - 1
-                if attention_mask[j, 0] == 0: # offset left padding
+                if attention_mask[j, 0] == 0: # offset left padding
                     start += attention_mask[j, :].nonzero()[0]
                 end = start + len(response_batch[j])
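For completeness, the start offset in this last hunk accounts for left padding when locating a response inside the concatenated sequence: the response begins after the query's last real token, which sits further right when the row opens with pads. A toy sketch under that assumption:

import torch

query_batch = [torch.tensor([5, 6, 7])]                 # 3 real query tokens
response_batch = [torch.tensor([8, 9])]                 # 2 response tokens
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1]])  # 2 left pads

j = 0
start = len(query_batch[j]) - 1                # last query token, ignoring padding
if attention_mask[j, 0] == 0:                  # row starts with padding
    start += attention_mask[j, :].nonzero()[0] # shift past the left pads
end = start + len(response_batch[j])
print(start, end)                              # tensor([4]) tensor([6])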