bf16 by default, gemma2 attns

Gemma2 fine-tuning will not work until https://github.com/huggingface/transformers/pull/31674 is merged.


Former-commit-id: da66c32c7be0adc28d2185b23e9f62d56acb961c
hiyouga
2024-06-28 06:00:26 +08:00
parent cfdf5a5a78
commit fda2cf677b
3 changed files with 9 additions and 3 deletions


@@ -54,7 +54,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         num_train_epochs = gr.Textbox(value="3.0")
         max_grad_norm = gr.Textbox(value="1.0")
         max_samples = gr.Textbox(value="100000")
-        compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")
+        compute_type = gr.Dropdown(choices=["bf16", "fp16", "fp32", "pure_bf16"], value="bf16")
     input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
     elem_dict.update(
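
Not part of this commit, but as context for the bf16-by-default change: a minimal sketch of how such a default is typically guarded by a hardware check, since bf16 is only available on Ampere-or-newer GPUs. The helper name `default_compute_type` is hypothetical and not from the LLaMA-Factory codebase.

```python
import torch


def default_compute_type() -> str:
    """Hypothetical helper: pick a mixed-precision default for the dropdown.

    Prefers bf16 on GPUs that support it, falls back to fp16 on older
    CUDA devices, and to fp32 when no GPU is present.
    """
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return "bf16"
    if torch.cuda.is_available():
        return "fp16"
    return "fp32"
```

In the diff above the default is simply hard-coded to "bf16"; users on pre-Ampere hardware can still switch the dropdown back to "fp16".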