Fix PPO trainer: read `upcast_layernorm` from `model_args` instead of `finetuning_args`
Former-commit-id: ca5b5823b03822ef899405d233a82396be997f44
This commit is contained in:
@@ -203,7 +203,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
r"""
|
r"""
|
||||||
Generates model's responses given queries.
|
Generates model's responses given queries.
|
||||||
"""
|
"""
|
||||||
if self.finetuning_args.upcast_layernorm:
|
if self.model_args.upcast_layernorm:
|
||||||
layernorm_params = dump_layernorm(self.model)
|
layernorm_params = dump_layernorm(self.model)
|
||||||
|
|
||||||
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
|
if batch["input_ids"].size(0) == 1: # handle llama2 ppo with gradient accumulation > 1
|
||||||
@@ -218,7 +218,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
**batch
|
**batch
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.finetuning_args.upcast_layernorm:
|
if self.model_args.upcast_layernorm:
|
||||||
restore_layernorm(self.model, layernorm_params)
|
restore_layernorm(self.model, layernorm_params)
|
||||||
|
|
||||||
query = batch["input_ids"].detach().cpu()
|
query = batch["input_ids"].detach().cpu()
|
||||||
|
|||||||
Reference in New Issue
Block a user