Former-commit-id: 28cac0e325bfd7a6c0c344ad2d46511613190cd7
This commit is contained in:
hiyouga
2024-07-24 18:33:39 +08:00
parent ff5ba97970
commit 211038584a
2 changed files with 15 additions and 15 deletions

View File

@@ -36,9 +36,11 @@ def calculate_flops(
"""
with get_accelerator().device(0):
chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.engine.model.device)
input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
flops, macs, params = get_model_profile(chat_model.model, kwargs=input_dict, print_profile=True, detailed=True)
flops, macs, params = get_model_profile(
chat_model.engine.model, kwargs=input_dict, print_profile=True, detailed=True
)
print("FLOPs:", flops)
print("MACs:", macs)
print("Params:", params)