Merge branch 'main' into main
Former-commit-id: 7be442f37d53a0c6324728fa1fa8e2c84d7f0fa5
This commit is contained in:
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from transformers.trainer_utils import SchedulerType
|
||||
@@ -40,7 +54,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
num_train_epochs = gr.Textbox(value="3.0")
|
||||
max_grad_norm = gr.Textbox(value="1.0")
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")
|
||||
compute_type = gr.Dropdown(choices=["bf16", "fp16", "fp32", "pure_bf16"], value="bf16")
|
||||
|
||||
input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
|
||||
elem_dict.update(
|
||||
@@ -152,10 +166,9 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
create_new_adapter = gr.Checkbox()
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
use_rslora = gr.Checkbox()
|
||||
use_dora = gr.Checkbox()
|
||||
|
||||
use_rslora = gr.Checkbox()
|
||||
use_dora = gr.Checkbox()
|
||||
use_pissa = gr.Checkbox()
|
||||
lora_target = gr.Textbox(scale=2)
|
||||
additional_target = gr.Textbox(scale=2)
|
||||
|
||||
@@ -168,6 +181,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
create_new_adapter,
|
||||
use_rslora,
|
||||
use_dora,
|
||||
use_pissa,
|
||||
lora_target,
|
||||
additional_target,
|
||||
}
|
||||
@@ -182,6 +196,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
create_new_adapter=create_new_adapter,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
use_pissa=use_pissa,
|
||||
lora_target=lora_target,
|
||||
additional_target=additional_target,
|
||||
)
|
||||
@@ -279,7 +294,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Column(scale=1):
|
||||
loss_viewer = gr.Plot()
|
||||
|
||||
input_elems.update({output_dir, config_path, device_count, ds_stage, ds_offload})
|
||||
input_elems.update({output_dir, config_path, ds_stage, ds_offload})
|
||||
elem_dict.update(
|
||||
dict(
|
||||
cmd_preview_btn=cmd_preview_btn,
|
||||
|
||||
Reference in New Issue
Block a user