refactor evaluation, upgrade trl to 074
Former-commit-id: ed09ebe2c1926ffdb0520b3866f7fd03a9aed046
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
||||
|
||||
try:
|
||||
@@ -17,6 +19,7 @@ except ImportError:
|
||||
_is_bf16_available = torch.cuda.is_bf16_supported()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import HfArgumentParser
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
@@ -74,7 +77,7 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
||||
return torch.float32
|
||||
|
||||
|
||||
def get_logits_processor() -> LogitsProcessorList:
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
"""
|
||||
@@ -93,6 +96,17 @@ def torch_gc() -> None:
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any, ...]:
    r"""
    Parses arguments from a dict, a config file named on the command line, or the command line itself.

    The source is chosen in this order:
      1. `args`, when it is not None (an empty dict still counts), via `parser.parse_dict`.
      2. A single command-line argument ending in `.yaml` / `.yml`, via `parser.parse_yaml_file`.
      3. A single command-line argument ending in `.json`, via `parser.parse_json_file`.
      4. Otherwise the regular command line, via `parser.parse_args_into_dataclasses`.

    Args:
        parser: the argument parser that performs the actual parsing.
        args: optional mapping of argument names to values; takes priority over `sys.argv`.

    Returns:
        Whatever the selected parser method returns (annotated as a variable-length tuple).
    """
    if args is not None:
        return parser.parse_dict(args)

    # A lone command-line argument is treated as a config file when its extension matches.
    if len(sys.argv) == 2:
        config_path = sys.argv[1]
        if config_path.endswith((".yaml", ".yml")):
            return parser.parse_yaml_file(os.path.abspath(config_path))
        if config_path.endswith(".json"):
            return parser.parse_json_file(os.path.abspath(config_path))

    return parser.parse_args_into_dataclasses()
|
||||
|
||||
|
||||
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
r"""
|
||||
Dispatches a pre-trained model to GPUs with balanced memory.
|
||||
|
||||
Reference in New Issue
Block a user