support autogptq in llama board #246
Former-commit-id: fea01226703d1534b5cf511bcb6a49e73bc86ce1
@@ -1,13 +1,9 @@
import gc
import os
import sys
import torch
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)

try:
    from transformers.utils import (
        is_torch_bf16_cpu_available,
@@ -106,22 +102,6 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
        return torch.float32


def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
    if args is not None:
        return parser.parse_dict(args)
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
        return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
    elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        return parser.parse_json_file(os.path.abspath(sys.argv[1]))
    else:
        (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
        if unknown_args:
            logger.warning(parser.format_help())
            logger.error(f'\nGot unknown args, potentially deprecated arguments: {unknown_args}\n')
            raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
        return (*parsed_args,)


def torch_gc() -> None:
    r"""
    Collects GPU memory.
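For context, a minimal sketch of how a helper like the parse_args shown above is typically driven. The ModelArguments dataclass, its fields, and the example values are hypothetical (not taken from this commit), and parse_args is assumed to be in scope; the real project wires its own argument dataclasses through HfArgumentParser.

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    # Hypothetical dataclass for illustration only.
    model_name_or_path: str = field(default="path/to/model")
    quantization_bit: Optional[int] = field(default=None)


parser = HfArgumentParser((ModelArguments,))

# Programmatic use (e.g. from a web UI): passing a dict short-circuits to
# parser.parse_dict and returns the populated dataclasses as a tuple.
(model_args,) = parse_args(parser, args={"model_name_or_path": "path/to/model", "quantization_bit": 4})

# CLI use: with no dict, a single .yaml or .json path in sys.argv is parsed as a
# config file; otherwise normal --flags are parsed, and any unknown flags cause
# parse_args to log the help text and raise a ValueError.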