Former-commit-id: 280c0f3f2cea4dfced797cc0e15f72b8b3a93542
This commit is contained in:
hiyouga
2024-09-05 03:02:59 +08:00
parent 7b01c0676c
commit 26d914b8fc
3 changed files with 4 additions and 3 deletions

View File

@@ -37,9 +37,9 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
for name in state_dict_a.keys():
if any(key in name for key in diff_keys):
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is False
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is False
else:
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5) is True
assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-3, atol=1e-4) is True
def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]: