[misc] lint code (#9395)

This commit is contained in:
Yaowei Zheng
2025-11-03 22:08:59 +08:00
committed by GitHub
parent 215580c77d
commit 3ae15da9c0
17 changed files with 82 additions and 75 deletions

View File

@@ -19,11 +19,10 @@ from transformers import AutoModelForCausalLM
class TestKernelPlugin(unittest.TestCase):
@patch('torch.accelerator.current_accelerator')
@patch("torch.accelerator.current_accelerator")
def test_apply_kernel(self, mock_get_accelerator):
mock_device = MagicMock()
mock_device.type = 'npu'
mock_device.type = "npu"
mock_get_accelerator.return_value = mock_device
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
@@ -31,7 +30,6 @@ class TestKernelPlugin(unittest.TestCase):
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
original_swiglu_forward = model.model.layers[0].mlp.forward
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm