support BLOOM models
Former-commit-id: 1314b6ea39a01aa8ac325e1d875ac013d43aec45
This commit is contained in:
@@ -58,7 +58,7 @@ def cast_layernorm_dtype(
|
||||
return model, layer_norm_state_dict
|
||||
|
||||
|
||||
class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||
class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
r"""
|
||||
Inherits PPOTrainer.
|
||||
"""
|
||||
@@ -130,7 +130,7 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
|
||||
unwrapped_model.gradient_checkpointing_disable()
|
||||
unwrapped_model.config.use_cache = True
|
||||
|
||||
# Get response from LLaMA
|
||||
# Get response from model
|
||||
query_tensors: torch.Tensor = batch["input_ids"]
|
||||
response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user