refactor export, fix #1190

Former-commit-id: 30e60e37023a7c4a2db033ffec0542efa3d5cdfb
Author: hiyouga
Date: 2023-10-15 16:01:48 +08:00
parent 68330eab2a
commit c2e84d4558
9 changed files with 52 additions and 49 deletions

View File

@@ -1,5 +1,4 @@
 from .data_args import DataArguments
 from .finetuning_args import FinetuningArguments
-from .general_args import GeneralArguments
 from .generating_args import GeneratingArguments
 from .model_args import ModelArguments

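With GeneralArguments gone, its former fields (notably the training stage) move into FinetuningArguments, as the next hunk shows. A minimal sketch of a call site after the refactor; the llmtuner.hparams package path is an assumption, since the diff only shows relative imports inside the package:

# Hypothetical call site after the refactor; the package path is an
# assumption (the diff shows only relative imports inside the package).
from llmtuner.hparams import FinetuningArguments

finetuning_args = FinetuningArguments()
print(finetuning_args.stage)  # "sft" by default; stage no longer lives on GeneralArguments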
View File

@@ -8,6 +8,10 @@ class FinetuningArguments:
r"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)
finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."}

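The new stage field is parsed like the rest of these dataclasses. A minimal, self-contained sketch of parsing it on the command line with transformers.HfArgumentParser; the class below re-declares only the two fields visible in the hunk and is illustrative, not the project's actual definition:

from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import HfArgumentParser


@dataclass
class FinetuningArguments:
    # Illustrative re-declaration of just the two fields shown above.
    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
        default="sft",
        metadata={"help": "Which stage will be performed in training."}
    )
    finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
        default="lora",
        metadata={"help": "Which fine-tuning method to use."}
    )


if __name__ == "__main__":
    # e.g. python parse_demo.py --stage dpo --finetuning_type lora
    parser = HfArgumentParser(FinetuningArguments)
    (finetuning_args,) = parser.parse_args_into_dataclasses()
    print(finetuning_args.stage, finetuning_args.finetuning_type)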
View File

@@ -46,6 +46,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "Adopt scaled rotary positional embeddings."}
     )
+    checkpoint_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
+    )
     flash_attn: Optional[bool] = field(
         default=False,
         metadata={"help": "Enable FlashAttention-2 for faster training."}
@@ -54,14 +58,14 @@ class ModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
     )
-    checkpoint_dir: Optional[str] = field(
-        default=None,
-        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
-    )
     reward_model: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
     )
+    upcast_layernorm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
+    )
     plot_loss: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
@@ -70,9 +74,9 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
-    upcast_layernorm: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
+    export_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory to save the exported model."}
     )

     def __post_init__(self):
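
The new export_dir field is what the export refactor centers on. A hedged sketch of the kind of export step it enables, using the standard Hugging Face save_pretrained API; the export_model helper, the variable names, and the tiny demo model are assumptions for illustration, not the project's actual code:

from dataclasses import dataclass, field
from typing import Optional

from transformers import AutoModelForCausalLM, AutoTokenizer


@dataclass
class ModelArguments:
    # Illustrative subset: only the field added in the hunk above.
    export_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory to save the exported model."}
    )


def export_model(model_args: ModelArguments, model_name: str = "sshleifer/tiny-gpt2") -> None:
    # Load and re-save the model plus tokenizer; a stand-in for the real
    # "merge adapters, then export" flow this commit refactors.
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if model_args.export_dir is not None:
        model.save_pretrained(model_args.export_dir)
        tokenizer.save_pretrained(model_args.export_dir)


export_model(ModelArguments(export_dir="exported_model"))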