update scripts

Former-commit-id: 51d087cbc14bf3c7dfa06b8b66052cd80a6081be
This commit is contained in:
hiyouga
2024-09-08 14:17:41 +08:00
parent eb5af3d90b
commit 3259397f89
6 changed files with 24 additions and 11 deletions

View File

@@ -102,8 +102,9 @@ def compute_device_flops() -> float:
def calculate_mfu(
model_name_or_path: str,
batch_size: int,
seq_length: int,
batch_size: int = 1,
seq_length: int = 1024,
num_steps: int = 100,
finetuning_type: str = "lora",
flash_attn: str = "auto",
deepspeed_stage: int = 0,
@@ -129,7 +130,7 @@ def calculate_mfu(
"output_dir": os.path.join("saves", "test_mfu"),
"overwrite_output_dir": True,
"per_device_train_batch_size": batch_size,
"max_steps": 100,
"max_steps": num_steps,
"bf16": True,
}
if deepspeed_stage in [2, 3]: