Former-commit-id: efbb32afdcf0d6aa4ca26f54c95f76dbb84f77dc
This commit is contained in:
hiyouga
2023-12-16 20:50:45 +08:00
parent f927601702
commit 790a31404a
3 changed files with 22 additions and 17 deletions

View File

@@ -206,6 +206,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
if self.finetuning_args.upcast_layernorm:
layernorm_params = dump_layernorm(self.model)
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
start_index = (batch["input_ids"][0] != self.tokenizer.pad_token_id).nonzero()[0].item()
for k, v in batch.items():
batch[k] = v[:, start_index:]
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
generate_output: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config,
@@ -220,7 +225,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
queries, responses = [], []
for i in range(len(query)):
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
if len(response_index) == 0:
@@ -228,7 +233,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
else:
response_length = response_index[-1].item() + 1
queries.append(query[i, query_length:]) # remove padding from left
queries.append(query[i, query_start_index:]) # remove padding from left
responses.append(response[i, :response_length]) # remove padding from right
return queries, responses