mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-01 20:23:37 +00:00
use pre-commit
Former-commit-id: 7cfede95df22a9ff236788f04159b6b16b8d04bb
This commit is contained in:
@@ -62,8 +62,8 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
|
||||
setattr(model, "default_head_bias", v_head_layer.bias.data.detach().clone())
|
||||
|
||||
device = v_head_layer.weight.device
|
||||
v_head_layer.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone().to(device)
|
||||
v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
|
||||
v_head_layer.weight.data = model.get_buffer(f"{target}_head_weight").detach().clone().to(device)
|
||||
v_head_layer.bias.data = model.get_buffer(f"{target}_head_bias").detach().clone().to(device)
|
||||
|
||||
|
||||
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
|
||||
|
||||
Reference in New Issue
Block a user