support BLOOM models

Former-commit-id: 1314b6ea39a01aa8ac325e1d875ac013d43aec45
This commit is contained in:
hiyouga
2023-05-31 16:54:06 +08:00
parent 181c776b58
commit 693c049eac
16 changed files with 134 additions and 90 deletions

View File

@@ -58,7 +58,7 @@ def cast_layernorm_dtype(
return model, layer_norm_state_dict
class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
class PPOPeftTrainer(PPOTrainer, PeftTrainer):
r"""
Inherits PPOTrainer.
"""
@@ -130,7 +130,7 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
# Get response from LLaMA
# Get response from model
query_tensors: torch.Tensor = batch["input_ids"]
response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)