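Update scripts: give calculate_mfu default values for batch_size (1) and seq_length (1024), and add a num_steps parameter (default 100) that replaces the hard-coded max_steps value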
Former-commit-id: 51d087cbc14bf3c7dfa06b8b66052cd80a6081be
This commit is contained in:
@@ -102,8 +102,9 @@ def compute_device_flops() -> float:
|
||||
|
||||
def calculate_mfu(
|
||||
model_name_or_path: str,
|
||||
batch_size: int,
|
||||
seq_length: int,
|
||||
batch_size: int = 1,
|
||||
seq_length: int = 1024,
|
||||
num_steps: int = 100,
|
||||
finetuning_type: str = "lora",
|
||||
flash_attn: str = "auto",
|
||||
deepspeed_stage: int = 0,
|
||||
@@ -129,7 +130,7 @@ def calculate_mfu(
|
||||
"output_dir": os.path.join("saves", "test_mfu"),
|
||||
"overwrite_output_dir": True,
|
||||
"per_device_train_batch_size": batch_size,
|
||||
"max_steps": 100,
|
||||
"max_steps": num_steps,
|
||||
"bf16": True,
|
||||
}
|
||||
if deepspeed_stage in [2, 3]:
|
||||
|
||||
Reference in New Issue
Block a user