@@ -1,6 +1,8 @@
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
@@ -216,3 +218,13 @@ class ModelArguments:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
|
||||
arg_dict = old_arg.to_dict()
|
||||
arg_dict.update(**kwargs)
|
||||
new_arg = cls(**arg_dict)
|
||||
new_arg.compute_dtype = old_arg.compute_dtype
|
||||
new_arg.device_map = old_arg.device_map
|
||||
new_arg.model_max_length = old_arg.model_max_length
|
||||
return new_arg
|
||||
|
||||
Reference in New Issue
Block a user