mirror of https://github.com/hiyouga/LlamaFactory.git
tiny fix
Former-commit-id: 97ba2027bb1ddc01a3c824c40d5a180828810c2c
@@ -50,7 +50,7 @@ class CustomDPOTrainer(DPOTrainer):
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
     def _prepare_deepspeed(self, model: "PreTrainedModelWrapper"):
-        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
+        # adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
         deepspeed_plugin = self.accelerator.state.deepspeed_plugin
         config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
         if model is not None:
@@ -75,7 +75,8 @@ class CustomDPOTrainer(DPOTrainer):
         # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
         if config_kwargs["zero_optimization"]["stage"] != 3:
             config_kwargs["zero_optimization"]["stage"] = 0
-        # lazy load
+
+        # Lazy load
         import deepspeed  # type: ignore
         model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
         model.eval()
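For context, a minimal sketch of the pattern these hunks touch: copy the trainer's DeepSpeed config, drop ZeRO to stage 0 unless stage 3 sharding is required, then build an inference-only DeepSpeed engine around the frozen reference model. This is not the repository's exact code; the helper name prepare_ref_model_with_deepspeed is hypothetical, and `accelerator` is assumed to be an accelerate.Accelerator configured with a DeepSpeed plugin.

from copy import deepcopy


def prepare_ref_model_with_deepspeed(model, accelerator):
    """Wrap a frozen reference model in an inference-only DeepSpeed engine (sketch)."""
    deepspeed_plugin = accelerator.state.deepspeed_plugin
    # Copy the training config so the trainer's own DeepSpeed engine is untouched.
    config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

    # Unless ZeRO stage 3 is needed to shard the weights, assume the reference
    # model fits in memory on each device and disable ZeRO (stage 0).
    if config_kwargs["zero_optimization"]["stage"] != 3:
        config_kwargs["zero_optimization"]["stage"] = 0

    # Lazy import: DeepSpeed is only required when this code path actually runs.
    import deepspeed  # type: ignore

    # No optimizer is passed; the reference model is only used for inference.
    engine, *_ = deepspeed.initialize(model=model, config=config_kwargs)
    engine.eval()
    return engine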