Former-commit-id: bdb54bcb477126687db789bd89f2df84e424a2a3
This commit is contained in:
hiyouga
2024-06-16 01:38:44 +08:00
parent 8393b08666
commit 727943f078
4 changed files with 5 additions and 4 deletions

View File

@@ -41,7 +41,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
state_dict_b = model_b.state_dict()
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
for name in state_dict_a.keys():
assert torch.allclose(state_dict_a[name], state_dict_b[name])
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5)
@pytest.fixture