use pre-commit
Former-commit-id: 7cfede95df22a9ff236788f04159b6b16b8d04bb
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -100,7 +99,7 @@ def calculate_ppl(
|
||||
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Stage does not supported: {}.".format(stage))
|
||||
raise NotImplementedError(f"Stage does not supported: {stage}.")
|
||||
|
||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
@@ -125,8 +124,8 @@ def calculate_ppl(
|
||||
with open(save_name, "w", encoding="utf-8") as f:
|
||||
json.dump(perplexities, f, indent=2)
|
||||
|
||||
print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
|
||||
print("Perplexities have been saved at {}.".format(save_name))
|
||||
print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
|
||||
print(f"Perplexities have been saved at {save_name}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user