Use a pytest fixture to patch value-head loading for the CPU test
Former-commit-id: 10761985691b9f934f7689c1f82aa6dd68febcca
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
@@ -43,10 +44,14 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
|
||||
assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
|
||||
|
||||
|
||||
def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
    """Load only the value-head weights from a state dict into ``self.v_head``.

    Intended as a drop-in replacement for
    ``AutoModelForCausalLMWithValueHead.post_init`` (the test assigns it to
    that attribute as a patch for the CPU test).

    Args:
        state_dict: Mapping of parameter names to tensors. Only entries whose
            key starts with ``"v_head."`` are used; all others are ignored.
    """
    prefix = "v_head."
    # Strip the prefix so the remaining key names are relative to ``self.v_head``.
    v_head_state = {
        name[len(prefix):]: tensor for name, tensor in state_dict.items() if name.startswith(prefix)
    }
    # strict=False: do not require an exact key match between the filtered
    # dict and the value head's own parameters.
    self.v_head.load_state_dict(v_head_state, strict=False)
|
||||
@pytest.fixture
def fix_valuehead_cpu_loading(monkeypatch: "pytest.MonkeyPatch"):
    """Replace ``AutoModelForCausalLMWithValueHead.post_init`` for one test.

    The replacement loads only the ``v_head.*`` entries of the given state
    dict into ``self.v_head``. The patch is installed through ``monkeypatch``
    so that pytest restores the original attribute at teardown instead of
    leaving the class modified for every test that runs afterwards.
    """

    def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]):
        prefix = "v_head."
        # Strip the prefix so the remaining key names are relative to ``self.v_head``.
        v_head_state = {
            name[len(prefix):]: tensor for name, tensor in state_dict.items() if name.startswith(prefix)
        }
        # strict=False: do not require an exact key match with the value head.
        self.v_head.load_state_dict(v_head_state, strict=False)

    monkeypatch.setattr(AutoModelForCausalLMWithValueHead, "post_init", post_init)
|
||||
|
||||
|
||||
def test_base():
|
||||
@@ -60,8 +65,8 @@ def test_base():
|
||||
compare_model(model, ref_model)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("fix_valuehead_cpu_loading")
|
||||
def test_valuehead():
|
||||
AutoModelForCausalLMWithValueHead.post_init = post_init # patch for CPU test
|
||||
model_args, _, finetuning_args, _ = get_infer_args(INFER_ARGS)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
model = load_model(
|
||||
|
||||
Reference in New Issue
Block a user