Former-commit-id: 7d0c4bd394fc3cba197db1719f1164b9dd66ac21
This commit is contained in:
hiyouga
2024-07-17 00:47:00 +08:00
parent 8c93921952
commit 341225a405
2 changed files with 30 additions and 6 deletions

View File

@@ -26,8 +26,16 @@ if TYPE_CHECKING:
@dataclass
class ComputeAccuracy:
def __post_init__(self):
def _dump(self) -> Optional[Dict[str, float]]:
result = None
if hasattr(self, "score_dict"):
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
self.score_dict = {"accuracy": []}
return result
def __post_init__(self):
self._dump()
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
@@ -38,4 +46,4 @@ class ComputeAccuracy:
self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i])
if compute_result:
return {"accuracy": float(np.mean(self.score_dict["accuracy"]))}
return self._dump()