[misc] lint code (#9395)
This commit is contained in:
@@ -19,11 +19,10 @@ from transformers import AutoModelForCausalLM
|
||||
|
||||
|
||||
class TestKernelPlugin(unittest.TestCase):
|
||||
|
||||
@patch('torch.accelerator.current_accelerator')
|
||||
@patch("torch.accelerator.current_accelerator")
|
||||
def test_apply_kernel(self, mock_get_accelerator):
|
||||
mock_device = MagicMock()
|
||||
mock_device.type = 'npu'
|
||||
mock_device.type = "npu"
|
||||
mock_get_accelerator.return_value = mock_device
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
|
||||
@@ -31,7 +30,6 @@ class TestKernelPlugin(unittest.TestCase):
|
||||
original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
|
||||
original_swiglu_forward = model.model.layers[0].mlp.forward
|
||||
|
||||
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
|
||||
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
|
||||
|
||||
Reference in New Issue
Block a user