add npu examples

Former-commit-id: 0f21e68e2dbd84c820d66d5c6d980004efc51d51
This commit is contained in:
hiyouga
2024-05-14 23:32:53 +08:00
parent 0a82e15e7c
commit ba0da83031
8 changed files with 85 additions and 19 deletions

View File

@@ -1,9 +1,10 @@
import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Dict
import torch
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras.logging import get_logger
@@ -44,6 +45,10 @@ def patch_config(
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if is_torch_npu_available():
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
torch.npu.set_compile_mode(jit_compile=use_jit_compile)
configure_attn_implementation(config, model_args)
configure_rope(config, model_args, is_trainable)
configure_longlora(config, model_args, is_trainable)
@@ -56,7 +61,7 @@ def patch_config(
logger.info("Using KV cache for faster generation.")
if getattr(config, "model_type", None) == "qwen":
setattr(config, "use_flash_attn", model_args.flash_attn)
setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, model_args.compute_dtype == dtype)

View File

@@ -22,7 +22,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
elif model_args.flash_attn == "sdpa":
if not is_sdpa_available():
logger.warning("Torch>=2.1.1 is required for SDPA attention.")
logger.warning("torch>=2.1.1 is required for SDPA attention.")
return
requested_attn_implementation = "sdpa"
@@ -52,4 +52,4 @@ def print_attn_implementation(config: "PretrainedConfig") -> None:
elif attn_implementation == "sdpa":
logger.info("Using torch SDPA for faster training and inference.")
else:
logger.info("Using vanilla Attention implementation.")
logger.info("Using vanilla attention implementation.")

View File

@@ -1,8 +1,3 @@
import os
import torch
from transformers import is_torch_npu_available
from llmtuner.train.tuner import run_exp
@@ -16,7 +11,4 @@ def _mp_fn(index):
if __name__ == "__main__":
if is_torch_npu_available():
use_jit_compile = os.getenv('JIT_COMPILE', 'False').lower() in ['true', '1']
torch.npu.set_compile_mode(jit_compile=use_jit_compile)
main()