support FlashAttention2

Former-commit-id: 23e56c5554b948d4f08ad87849b261eafd2c7890
This commit is contained in:
hiyouga
2023-09-10 20:43:56 +08:00
parent b481ad58e6
commit a402161631
9 changed files with 875 additions and 115 deletions

View File

@@ -43,6 +43,10 @@ class ModelArguments:
default=None,
metadata={"help": "Adopt scaled rotary positional embeddings."}
)
flash_attn: Optional[bool] = field(
default=False,
metadata={"help": "Enable flash attention for faster training."}
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}