update patcher

Former-commit-id: afb365e515d615dd62f791622450debab60ce5cc
This commit is contained in:
hiyouga
2024-06-19 21:27:00 +08:00
parent a7d7f79855
commit 5f5d4c1923
3 changed files with 10 additions and 7 deletions

View File

@@ -70,5 +70,5 @@ def test_upcast_lmhead_output():
tokenizer_module = load_tokenizer(model_args)
model = load_model(tokenizer_module["tokenizer"], model_args, finetuning_args, is_trainable=True)
inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
outputs: "torch.Tensor" = model.lm_head(inputs)
outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
assert outputs.dtype == torch.float32