Former-commit-id: f7f440986b0ae3b38ea9f2da80789629d4f79ea1
This commit is contained in:
hiyouga
2024-06-16 01:06:41 +08:00
parent 14f7bfc545
commit 05f3a3c944
22 changed files with 27 additions and 25 deletions

View File

@@ -41,7 +41,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
state_dict_b = model_b.state_dict()
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
for name in state_dict_a.keys():
assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
assert torch.allclose(state_dict_a[name], state_dict_b[name])
@pytest.fixture