fix system prompt and tests
Former-commit-id: 955efca677b299749f3d40d587ee310951537543
This commit is contained in:
@@ -76,9 +76,15 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
|
||||
if isinstance(batch_a[key], torch.Tensor):
|
||||
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
|
||||
elif isinstance(batch_a[key], list) and all(isinstance(item, torch.Tensor) for item in batch_a[key]):
|
||||
assert len(batch_a[key]) == len(batch_b[key])
|
||||
for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
|
||||
assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
|
||||
assert len(batch_a[key]) == len(batch_b[key])
|
||||
for tensor_a, tensor_b in zip(batch_a[key], batch_b[key]):
|
||||
assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
|
||||
elif isinstance(batch_a[key], list) and all(isinstance(item, list) for item in batch_a[key]) \
|
||||
and len(batch_a[key])>0 and len(batch_a[key][0])>0 and isinstance(batch_a[key][0][0], torch.Tensor):
|
||||
for item_a, item_b in zip(batch_a[key], batch_b[key]):
|
||||
assert len(item_a) == len(item_a)
|
||||
for tensor_a, tensor_b in zip(item_a, item_b):
|
||||
assert torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-5)
|
||||
else:
|
||||
assert batch_a[key] == batch_b[key]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user