mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 08:33:38 +00:00
[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -39,7 +39,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -59,7 +59,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Evaluator:
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
|
||||
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
|
||||
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
||||
@@ -69,7 +69,7 @@ class Evaluator:
|
||||
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
|
||||
|
||||
@torch.inference_mode()
|
||||
def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
|
||||
def batch_inference(self, batch_input: dict[str, "torch.Tensor"]) -> list[str]:
|
||||
logits = self.model(**batch_input).logits
|
||||
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
|
||||
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
|
||||
@@ -88,7 +88,7 @@ class Evaluator:
|
||||
)
|
||||
|
||||
with open(mapping, encoding="utf-8") as f:
|
||||
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
||||
categorys: dict[str, dict[str, str]] = json.load(f)
|
||||
|
||||
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
|
||||
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
||||
@@ -136,7 +136,7 @@ class Evaluator:
|
||||
pbar.close()
|
||||
self._save_results(category_corrects, results)
|
||||
|
||||
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
|
||||
def _save_results(self, category_corrects: dict[str, "NDArray"], results: dict[str, dict[int, str]]) -> None:
|
||||
score_info = "\n".join(
|
||||
[
|
||||
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
|
||||
|
||||
Reference in New Issue
Block a user