[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -62,5 +62,5 @@ def test_upcast_layernorm():
|
||||
def test_upcast_lmhead_output():
|
||||
model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
|
||||
inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
|
||||
outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
|
||||
outputs: torch.Tensor = model.get_output_embeddings()(inputs)
|
||||
assert outputs.dtype == torch.float32
|
||||
|
||||
Reference in New Issue
Block a user