support batch_eval_metrics, fix #4826
Former-commit-id: 3fe1df17188825f8a32fbe6a1294b4b532ce0c85
This commit is contained in:
@@ -12,14 +12,30 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...extras.misc import numpify
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import EvalPrediction
|
||||
|
||||
|
||||
def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
    r"""Compute pairwise reward-model accuracy.

    Args:
        eval_preds: prediction object whose ``predictions`` holds two parallel
            arrays: ``predictions[0]`` are the scores of the chosen responses and
            ``predictions[1]`` the scores of the rejected responses.

    Returns:
        A dict with key ``"accuracy"``: the fraction of pairs where the chosen
        score exceeds the rejected score. Cast to a plain ``float`` so the value
        is JSON-serializable and consistent with ``ComputeAccuracy``
        (``np.mean`` would otherwise return ``np.float64``).
    """
    return {"accuracy": float(np.mean(eval_preds.predictions[0] > eval_preds.predictions[1]))}
|
||||
@dataclass
class ComputeAccuracy:
    r"""Stateful accuracy metric supporting ``batch_eval_metrics``.

    An instance accumulates per-sample chosen-vs-rejected comparisons across
    batches via ``__call__`` and reports the running mean when
    ``compute_result=True`` (the final batch of an evaluation round).

    Fix over the original: the accumulator is reset after each final
    computation. The ``Trainer`` keeps a single ``compute_metrics`` instance
    alive for the whole run, so without the reset, later evaluation rounds
    would mix in stale scores from earlier rounds.
    """

    def _dump(self) -> Optional[Dict[str, float]]:
        # Snapshot the accumulated scores (if any), then reset the accumulator
        # so the instance is clean for the next evaluation round.
        result = None
        if hasattr(self, "score_dict"):
            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}

        self.score_dict = {"accuracy": []}
        return result

    def __post_init__(self):
        # Initialize the (empty) accumulator; the dumped result is discarded.
        self._dump()

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
        chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
        if not chosen_scores.shape:
            # 0-d arrays (single sample): compare scalars directly.
            self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
        else:
            for i in range(len(chosen_scores)):
                self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i])

        if compute_result:
            return self._dump()
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..callbacks import fix_valuehead_checkpoint
|
||||
from ..trainer_utils import create_modelcard_and_push
|
||||
from .metric import compute_accuracy
|
||||
from .metric import ComputeAccuracy
|
||||
from .trainer import PairwiseTrainer
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ def run_rm(
|
||||
finetuning_args=finetuning_args,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
compute_metrics=compute_accuracy,
|
||||
compute_metrics=ComputeAccuracy(),
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user