fix system prompt and tests

Former-commit-id: 955efca677b299749f3d40d587ee310951537543
This commit is contained in:
fzc8578
2025-01-13 14:18:06 +08:00
parent 6d6acd0213
commit 07798e4aad
4 changed files with 17 additions and 10 deletions

View File

@@ -76,9 +76,15 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
if isinstance(batch_a[key], torch.Tensor):
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]):
assert len(batch_a[key]) == len(batch_b[key])
for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
assert len(batch_a[key]) == len(batch_b[key])
for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
elif isinstance(batch_a[key], list) and all(isinstance(item, list) for item in batch_a[key]) \
and len(batch_a[key])>0 and len(batch_a[key][0])>0 and isinstance(batch_a[key][0][0], torch.Tensor):
for item_a, item_b in zip(batch_a[key], batch_b[key]):
assert len(item_a) == len(item_a)
for tensor_a, tensor_b in zip(item_a, item_b):
assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
else:
assert batch_a[key] == batch_b[key]