Former-commit-id: cf28e7418c2eb07e86923a53ef832ef218e45af1
This commit is contained in:
hiyouga
2024-09-30 23:28:55 +08:00
parent cf72aec098
commit 20ee1d2e19
4 changed files with 30 additions and 20 deletions

View File

@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, dataclass, field, fields
from dataclasses import dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union
import torch
@@ -308,20 +308,18 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
raise ValueError("Quantization dataset is necessary for exporting.")
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@classmethod
def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
arg_dict = old_arg.to_dict()
arg_dict.update(**kwargs)
for attr in fields(cls):
if not attr.init:
arg_dict.pop(attr.name)
def copyfrom(cls, source: "Self", **kwargs) -> "Self":
init_args, lazy_args = {}, {}
for attr in fields(source):
if attr.init:
init_args[attr.name] = getattr(source, attr.name)
else:
lazy_args[attr.name] = getattr(source, attr.name)
new_arg = cls(**arg_dict)
new_arg.compute_dtype = old_arg.compute_dtype
new_arg.device_map = old_arg.device_map
new_arg.model_max_length = old_arg.model_max_length
new_arg.block_diag_attn = old_arg.block_diag_attn
return new_arg
init_args.update(kwargs)
result = cls(**init_args)
for name, value in lazy_args.items():
setattr(result, name, value)
return result