mirror of https://github.com/hiyouga/LlamaFactory.git
[misc] fix import error (#9299)
@@ -20,15 +20,17 @@ from dataclasses import asdict, dataclass, field, fields
 from typing import Any, Literal, Optional, Union
 
 import torch
+from omegaconf import OmegaConf
+from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
-from omegaconf import OmegaConf
+
 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
 from ..extras.logging import get_logger
 
 
 logger = get_logger(__name__)
 
 
 @dataclass
 class BaseModelArguments:
     r"""Arguments pertaining to the model."""
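Aside (not part of the commit): a minimal hedged sketch of what the newly imported private helper _convert_str_dict is typically used for, namely coercing string values inside dict-typed arguments. The example input and the exact coercion behavior are assumptions about the upstream transformers helper, not facts taken from this diff.

    from transformers.training_args import _convert_str_dict

    # Dict-typed hyperparameters often arrive with string values when read
    # from a CLI override or a YAML config file.
    raw_rope_scaling = {"rope_type": "linear", "factor": "2.0"}

    # Assumption: the private helper converts numeric/boolean strings and
    # returns the dict; it is not a public transformers API and may change.
    converted = _convert_str_dict(raw_rope_scaling)
    print(converted)  # expected under that assumption: {'rope_type': 'linear', 'factor': 2.0}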
@@ -168,7 +170,7 @@ class BaseModelArguments:
         default="offload",
         metadata={"help": "Path to offload model weights."},
     )
-    use_cache: bool = field(
+    use_kv_cache: bool = field(
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
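Aside (interpretation, not from the commit): the rename means configs that previously set use_cache would now set use_kv_cache, assuming no compatibility alias is added elsewhere. A minimal sketch of the renamed field in isolation; the wiring in the last comment is hypothetical.

    from dataclasses import dataclass, field

    @dataclass
    class BaseModelArguments:
        # Renamed field as introduced by this diff; all other fields omitted.
        use_kv_cache: bool = field(
            default=True,
            metadata={"help": "Whether or not to use KV cache in generation."},
        )

    model_args = BaseModelArguments(use_kv_cache=False)
    # Hypothetical wiring (not shown in this diff): the flag would typically be
    # copied onto the loaded model config, e.g. config.use_cache = model_args.use_kv_cache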
@@ -81,6 +81,11 @@ class RayArguments:
 class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
     r"""Arguments pertaining to the trainer."""
 
+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={"help": "deprecated"},
+    )
+
     def __post_init__(self):
         Seq2SeqTrainingArguments.__post_init__(self)
         RayArguments.__post_init__(self)
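Aside (interpretation, not stated in the commit): re-declaring overwrite_output_dir with a "deprecated" help string is presumably there so that configs written for older transformers releases, which still pass the option, keep parsing instead of failing on an unknown field. A minimal self-contained sketch of that effect, using a stand-in dataclass and placeholder defaults rather than the real Seq2SeqTrainingArguments hierarchy:

    from dataclasses import dataclass, field

    @dataclass
    class TrainingArguments:  # stand-in; the real class also inherits RayArguments and Seq2SeqTrainingArguments
        output_dir: str = field(default="saves/demo", metadata={"help": "Where to write outputs."})
        # Kept only for backward compatibility; the value is ignored.
        overwrite_output_dir: bool = field(default=False, metadata={"help": "deprecated"})

    legacy_config = {"output_dir": "saves/demo", "overwrite_output_dir": True}
    args = TrainingArguments(**legacy_config)  # parses without an unexpected-keyword error
    print(args.overwrite_output_dir)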