improve rlhf
Former-commit-id: e441780e3db256ca09a442ea9254e7ce16898a07
This commit is contained in:
@@ -12,11 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Sequence, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
    """Return the fraction of positions where the first score beats the second.

    Args:
        eval_preds: A two-element sequence. Only the first element is used; it
            must itself be indexable as a pair of arrays of equal shape. The
            second element is ignored.
            NOTE(review): presumably (predictions, label_ids) with predictions
            holding (chosen, rejected) reward scores; confirm against the caller.

    Returns:
        A dict with the single key ``"accuracy"``: the count of positions where
        ``preds[0] > preds[1]`` divided by ``len(preds[0])``.
    """
    # Unpacking requires exactly two elements; the second one is unused here.
    preds, _ = eval_preds
    wins = (preds[0] > preds[1]).sum()
    return {"accuracy": wins / len(preds[0])}
|
||||
if TYPE_CHECKING:
|
||||
from transformers import EvalPrediction
|
||||
|
||||
|
||||
def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
    """Return the fraction of positions where the first score beats the second.

    Args:
        eval_preds: Object whose ``predictions`` attribute is indexable as a
            pair of arrays; element 0 is compared elementwise against element 1.
            NOTE(review): presumably (chosen, rejected) reward scores; confirm
            against the trainer that produces the predictions.

    Returns:
        A dict with the single key ``"accuracy"``: the mean of the boolean
        comparison ``predictions[0] > predictions[1]``.
    """
    first = eval_preds.predictions[0]
    second = eval_preds.predictions[1]
    # Mean of a boolean array is the share of True entries (strictly greater).
    return {"accuracy": np.mean(first > second)}
|
||||
|
||||
Reference in New Issue
Block a user