[model] add llama4 (#7611)

This commit is contained in:
hoshi-hiyouga
2025-04-06 13:42:31 +08:00
committed by GitHub
parent d4cfa9507e
commit 831e7f1cfd
11 changed files with 167 additions and 8 deletions

View File

@@ -243,7 +243,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
mini_batch = {
"input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
-"attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size]
+"attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
}
mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)