support batch_eval_metrics, fix #4826

Former-commit-id: 3fe1df17188825f8a32fbe6a1294b4b532ce0c85
Author: hiyouga
Date:   2024-07-17 00:33:00 +08:00
Parent: 45367105fc
Commit: 8c93921952

7 changed files with 85 additions and 36 deletions

View File

@@ -12,14 +12,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Dict
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, Optional
 
 import numpy as np
 
+from ...extras.misc import numpify
+
 if TYPE_CHECKING:
     from transformers import EvalPrediction
 
 
-def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
-    return {"accuracy": np.mean(eval_preds.predictions[0] > eval_preds.predictions[1])}
+@dataclass
+class ComputeAccuracy:
+    def __post_init__(self):
+        self.score_dict = {"accuracy": []}
+
+    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
+        chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
+        if not chosen_scores.shape:
+            self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
+        else:
+            for i in range(len(chosen_scores)):
+                self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i])
+
+        if compute_result:
+            return {"accuracy": float(np.mean(self.score_dict["accuracy"]))}
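Note: the snippet below is a minimal, self-contained sketch of the calling convention this class targets. With batch_eval_metrics enabled, the Trainer is expected to invoke the metric callable once per evaluation batch with compute_result=False and once more on the final batch with compute_result=True, so only the running score list is kept in memory instead of all eval predictions. The stand-in mirrors the class above but swaps numpify for np.asarray and fakes EvalPrediction with a SimpleNamespace, so it runs without llamafactory or transformers installed.

from dataclasses import dataclass
from types import SimpleNamespace
from typing import Dict, Optional

import numpy as np


@dataclass
class ComputeAccuracy:
    """Stand-in for the class above; accumulates pairwise wins across eval batches."""

    def __post_init__(self):
        self.score_dict = {"accuracy": []}

    def __call__(self, eval_preds, compute_result: bool = True) -> Optional[Dict[str, float]]:
        # predictions[0] / predictions[1] hold the chosen / rejected reward scores
        chosen, rejected = np.asarray(eval_preds.predictions[0]), np.asarray(eval_preds.predictions[1])
        if not chosen.shape:  # 0-dim array, i.e. a single scalar score
            self.score_dict["accuracy"].append(chosen > rejected)
        else:
            for i in range(len(chosen)):
                self.score_dict["accuracy"].append(chosen[i] > rejected[i])
        if compute_result:  # only the final batch returns the aggregated metric
            return {"accuracy": float(np.mean(self.score_dict["accuracy"]))}


metric = ComputeAccuracy()
batches = [
    SimpleNamespace(predictions=(np.array([1.2, 0.4]), np.array([0.3, 0.9]))),
    SimpleNamespace(predictions=(np.array([2.0, 1.1]), np.array([0.5, 0.7]))),
]
for step, preds in enumerate(batches):
    result = metric(preds, compute_result=(step == len(batches) - 1))  # None until the last batch
print(result)  # {'accuracy': 0.75} -- 3 of 4 chosen scores beat the rejected ones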

View File

@@ -22,7 +22,7 @@ from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..callbacks import fix_valuehead_checkpoint
 from ..trainer_utils import create_modelcard_and_push
-from .metric import compute_accuracy
+from .metric import ComputeAccuracy
 from .trainer import PairwiseTrainer
@@ -55,7 +55,7 @@ def run_rm(
         finetuning_args=finetuning_args,
         data_collator=data_collator,
         callbacks=callbacks,
-        compute_metrics=compute_accuracy,
+        compute_metrics=ComputeAccuracy(),
         **dataset_module,
         **tokenizer_module,
     )
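Note: the call site now passes an instance, ComputeAccuracy(), rather than the bare compute_accuracy function, because the metric carries state across batches. As a rough, hedged sketch of the transformers-side wiring this relies on (not run_rm's actual internals): the flag that toggles compute_result lives on TrainingArguments. The output_dir and batch size below are placeholders, and this assumes an installed transformers version that already exposes batch_eval_metrics.

from transformers import TrainingArguments

# batch_eval_metrics=True makes the Trainer call compute_metrics after every eval
# batch (compute_result=False) and once more at the end (compute_result=True),
# instead of gathering all eval predictions in memory first.
args = TrainingArguments(
    output_dir="saves/rm_demo",      # placeholder path
    per_device_eval_batch_size=4,    # placeholder value
    batch_eval_metrics=True,
)

# trainer = PairwiseTrainer(..., args=args, compute_metrics=ComputeAccuracy(), ...)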