remove visual_inputs, fix qlora
Former-commit-id: be30c01c4f1482520ece770bd54c6a4837c26f0a
This commit is contained in:
@@ -16,15 +16,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
r"""
|
||||
@@ -121,10 +118,6 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to enable liger kernel for faster training."},
|
||||
)
|
||||
visual_inputs: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whethor or not to use multimodal LLM that accepts visual inputs."},
|
||||
)
|
||||
moe_aux_loss_coef: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
|
||||
@@ -225,19 +218,31 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
|
||||
)
|
||||
compute_dtype: Optional[torch.dtype] = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
|
||||
)
|
||||
device_map: Optional[Union[str, Dict[str, Any]]] = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
|
||||
)
|
||||
model_max_length: Optional[int] = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
|
||||
)
|
||||
block_diag_attn: bool = field(
|
||||
default=False,
|
||||
init=False,
|
||||
metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.compute_dtype: Optional["torch.dtype"] = None
|
||||
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
|
||||
self.model_max_length: Optional[int] = None
|
||||
self.block_diag_attn: bool = False
|
||||
|
||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||
|
||||
if self.visual_inputs and self.use_unsloth:
|
||||
raise ValueError("Unsloth does not support MLLM yet. Stay tuned.")
|
||||
|
||||
if self.adapter_name_or_path is not None: # support merging multiple lora weights
|
||||
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user