Former-commit-id: cf0758b03e9b8b4931ba790a9726b8256ee4286c
This commit is contained in:
hiyouga
2024-09-05 22:27:48 +08:00
parent 9bdba2f6a8
commit b48b47d519
4 changed files with 9 additions and 8 deletions

View File

@@ -49,17 +49,17 @@ INFER_ARGS = {
"infer_dtype": "float16",
}
CI_OS = os.environ.get("CI_OS", "")
OS_NAME = os.environ.get("OS_NAME", "")
@pytest.mark.skipif(CI_OS.startswith("windows"), reason="Skip for windows.")
@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
def test_pissa_train():
model = load_train_model(**TRAIN_ARGS)
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=True)
compare_model(model, ref_model)
@pytest.mark.skipif(CI_OS.startswith("windows"), reason="Skip for windows.")
@pytest.mark.xfail(OS_NAME.startswith("windows"), reason="Known connection error on Windows.")
def test_pissa_inference():
model = load_infer_model(**INFER_ARGS)
ref_model = load_reference_model(TINY_LLAMA_PISSA, TINY_LLAMA_PISSA, use_pissa=True, is_trainable=False)