Former-commit-id: 067ba6e6cb4d8a1d95bba0a108f73008416a2865
This commit is contained in:
hiyouga
2024-12-19 12:16:30 +00:00
parent 0a465fc3ca
commit 0385c60177
6 changed files with 22 additions and 16 deletions

View File

@@ -15,6 +15,8 @@
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional
from transformers import GenerationConfig
@dataclass
class GeneratingArguments:
@@ -69,10 +71,17 @@ class GeneratingArguments:
metadata={"help": "Whether or not to remove special tokens in the decoding."},
)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self, obey_generation_config: bool = False) -> Dict[str, Any]:
args = asdict(self)
if args.get("max_new_tokens", -1) > 0:
args.pop("max_length", None)
else:
args.pop("max_new_tokens", None)
if obey_generation_config:
generation_config = GenerationConfig()
for key in list(args.keys()):
if not hasattr(generation_config, key):
args.pop(key)
return args