allow non-packing pretraining
Former-commit-id: 3fee5cc5a3db9ce874ad90f2500ec092d904bd4e
@@ -4,7 +4,7 @@ import gradio as gr
 from transformers.trainer_utils import SchedulerType
 
 from ...extras.constants import TRAINING_STAGES
-from ..common import DEFAULT_DATA_DIR, list_adapters, list_dataset
+from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
 from ..components.data import create_preview_box
 from ..utils import gen_plot
 
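For context, the `autoset_packing` helper added to the import above is defined in `..common` and its body is not part of this diff. A minimal sketch of what it plausibly does, assuming it simply pre-checks the packing box when the selected stage is pre-training and leaves it unchecked otherwise, which is what makes non-packing pretraining a user choice:

# Hypothetical sketch only; the real autoset_packing lives in ..common and is
# not shown in this commit.
from typing import Any, Dict

import gradio as gr

# Assumed location is webui/common.py, hence the two-dot relative import; the
# "pt" key for pre-training is assumed from the repository's TRAINING_STAGES map.
from ..extras.constants import TRAINING_STAGES


def autoset_packing(training_stage: str) -> Dict[str, Any]:
    # Assumed behavior: default packing on for pre-training and off elsewhere,
    # so packing is no longer forced for the pretraining stage.
    return gr.update(value=(TRAINING_STAGES[training_stage] == "pt"))
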
@@ -78,7 +78,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
 
         with gr.Row():
             resize_vocab = gr.Checkbox()
-            sft_packing = gr.Checkbox()
+            packing = gr.Checkbox()
             upcast_layernorm = gr.Checkbox()
             use_llama_pro = gr.Checkbox()
             shift_attn = gr.Checkbox()
@@ -91,7 +91,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             neftune_alpha,
             optim,
             resize_vocab,
-            sft_packing,
+            packing,
             upcast_layernorm,
             use_llama_pro,
             shift_attn,
@@ -106,7 +106,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             neftune_alpha=neftune_alpha,
             optim=optim,
             resize_vocab=resize_vocab,
-            sft_packing=sft_packing,
+            packing=packing,
             upcast_layernorm=upcast_layernorm,
             use_llama_pro=use_llama_pro,
             shift_attn=shift_attn,
@@ -166,7 +166,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
                [engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
                [reward_model],
                queue=False,
-            )
+            ).then(autoset_packing, [training_stage], [packing], queue=False)
 
         input_elems.update({dpo_beta, dpo_ftx, reward_model})
         elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
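The last hunk extends an existing Gradio event chain rather than registering a new listener: `.then()` schedules a follow-up callback after the previous one finishes, so the same event that refreshes the reward model list also resets the packing checkbox through `autoset_packing`. A stripped-down, standalone sketch of that chaining pattern, with illustrative component names and stage mapping rather than the ones from this file:

# Standalone illustration of the change(...).then(...) chaining pattern; names
# and the stage mapping are illustrative, not copied from the repository.
import gradio as gr

STAGES = {"Supervised Fine-Tuning": "sft", "Pre-Training": "pt"}


def refresh_choices(stage: str):
    # Stand-in for the first callback in the chain (e.g. reloading reward models).
    return gr.update()


def autoset_packing(stage: str):
    # Same idea as the helper wired in above: packing defaults on for pre-training.
    return gr.update(value=(STAGES[stage] == "pt"))


with gr.Blocks() as demo:
    training_stage = gr.Dropdown(choices=list(STAGES), value="Supervised Fine-Tuning")
    reward_model = gr.Dropdown(choices=[])
    packing = gr.Checkbox(label="Pack sequences")

    # The first listener runs on the stage change; .then() chains the packing
    # reset so it executes only after the refresh completes.
    training_stage.change(refresh_choices, [training_stage], [reward_model], queue=False).then(
        autoset_packing, [training_stage], [packing], queue=False
    )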