tiny fix
Former-commit-id: f7f440986b0ae3b38ea9f2da80789629d4f79ea1
This commit is contained in:
@@ -41,7 +41,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
|
||||
state_dict_b = model_b.state_dict()
|
||||
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||
for name in state_dict_a.keys():
|
||||
assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
|
||||
assert torch.allclose(state_dict_a[name], state_dict_b[name])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Reference in New Issue
Block a user