refactor multimodal (mm) training

Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
This commit is contained in:
hiyouga
2024-08-30 02:14:31 +08:00
parent 77c2c7076b
commit c62a6ca59d
29 changed files with 499 additions and 312 deletions

View File

@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
    """Monkey-patch the LLaMA attention classes with block-diagonal-mask-aware forwards.

    Replaces the ``forward`` method of ``LlamaAttention``, ``LlamaFlashAttention2``
    and ``LlamaSdpaAttention`` with the custom implementations defined in this
    module. The patch targets a pinned transformers version range because it
    overrides private method internals that change between releases.

    Raises:
        ImportError: via ``require_version`` if the installed transformers
            version is outside the supported ``>=4.41.2,<=4.44.3`` range.
    """
    # Single up-to-date version pin; the older <=4.43.4 pin was superseded
    # and keeping both would reject transformers 4.44.x unnecessarily.
    require_version("transformers>=4.41.2,<=4.44.3", "To fix: pip install transformers>=4.41.2,<=4.44.3")
    LlamaAttention.forward = llama_attention_forward
    LlamaFlashAttention2.forward = llama_flash_attention_2_forward
    LlamaSdpaAttention.forward = llama_sdpa_attention_forward