Mirror of https://github.com/hiyouga/LlamaFactory.git, synced 2026-02-01 08:13:38 +00:00
fix PPO trainer #551, update readme
Former-commit-id: faead74849470cebae9e37cde5fab2a71b32aa43
@@ -1,7 +1,4 @@
-import torch
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
-
-from llmtuner.extras.constants import LAYERNORM_NAMES
+from typing import TYPE_CHECKING, Literal
 
 if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead
@@ -18,22 +15,3 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
         "summary.weight": getattr(model, "{}_head_weight".format(target)),
         "summary.bias": getattr(model, "{}_head_bias".format(target))
     })
-
-
-def cast_layernorm_dtype(
-    model: "AutoModelForCausalLMWithValueHead",
-    layer_norm_names: List[str] = LAYERNORM_NAMES,
-    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
-) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
-
-    layer_norm_state_dict = {}
-
-    for name, param in model.named_parameters():
-        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            if layer_norm_params is not None:
-                param.data = layer_norm_params[name]  # restore float32 weights
-            else:
-                layer_norm_state_dict[name] = param.data.detach().clone()  # store float32 weights for stability
-                param.data = param.data.to(torch.float16)
-
-    return model, layer_norm_state_dict
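For context, the second hunk deletes the cast_layernorm_dtype helper from the PPO utilities. Below is a minimal, self-contained sketch of the pattern that helper implemented: 1-D layer-norm parameters are saved as float32 copies for numerical stability while the live weights are cast to float16, and passing the saved copies back in restores float32. The LAYERNORM_NAMES values and the rollout calls in the usage comment are illustrative assumptions, not code from this commit.

import torch

# Assumed substrings identifying layer-norm parameters (illustrative only).
LAYERNORM_NAMES = ["norm", "ln_f", "ln_1", "ln_2"]

def cast_layernorm_dtype(model, layer_norm_names=LAYERNORM_NAMES, layer_norm_params=None):
    """Cast layer-norm weights to float16, or restore previously saved float32 copies."""
    layer_norm_state_dict = {}
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(ln_name in name for ln_name in layer_norm_names):
            if layer_norm_params is not None:
                param.data = layer_norm_params[name]  # restore float32 weights
            else:
                layer_norm_state_dict[name] = param.data.detach().clone()  # keep float32 copies for stability
                param.data = param.data.to(torch.float16)
    return model, layer_norm_state_dict

# Hypothetical usage around a PPO rollout (not from this commit):
# model, ln_fp32 = cast_layernorm_dtype(model)                        # fp16 for generation
# responses = model.generate(query_tensors)                           # rollout step
# model, _ = cast_layernorm_dtype(model, layer_norm_params=ln_fp32)   # back to float32 for the update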