fix tol
Former-commit-id: bdb54bcb477126687db789bd89f2df84e424a2a3
This commit is contained in:
@@ -67,9 +67,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
|
||||
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||
for name in state_dict_a.keys():
|
||||
if any(key in name for key in diff_keys):
|
||||
assert torch.allclose(state_dict_a[name], state_dict_b[name]) is False
|
||||
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
|
||||
else:
|
||||
assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
|
||||
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Reference in New Issue
Block a user