[model] add llama4 (#7611)
This commit is contained in:
@@ -243,7 +243,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
|
||||
mini_batch = {
|
||||
"input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
|
||||
"attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size]
|
||||
"attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
|
||||
}
|
||||
mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
|
||||
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
|
||||
|
||||
Reference in New Issue
Block a user