allow non-packing pretraining
Former-commit-id: 3fee5cc5a3db9ce874ad90f2500ec092d904bd4e
This commit is contained in:
@@ -18,9 +18,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class WebChatModel(ChatModel):
|
||||
def __init__(
|
||||
self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
|
||||
) -> None:
|
||||
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
|
||||
self.manager = manager
|
||||
self.demo_mode = demo_mode
|
||||
self.engine: Optional["BaseEngine"] = None
|
||||
|
||||
@@ -104,10 +104,12 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
||||
return {}
|
||||
|
||||
|
||||
def list_dataset(
|
||||
dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0]
|
||||
) -> Dict[str, Any]:
|
||||
def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
|
||||
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
|
||||
ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
|
||||
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
|
||||
return gr.update(value=[], choices=datasets)
|
||||
|
||||
|
||||
def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
    """Build a Gradio update that toggles a value based on the training stage.

    Args:
        training_stage: A key of ``TRAINING_STAGES``; defaults to the first
            key of that mapping (evaluated once, at definition time).

    Returns:
        The result of ``gr.update`` whose ``value`` is ``True`` when the
        selected stage maps to ``"pt"`` and ``False`` otherwise.

    Raises:
        KeyError: If ``training_stage`` is not a key of ``TRAINING_STAGES``.
    """
    # "pt" is presumably the pre-training stage identifier; confirm against
    # the TRAINING_STAGES definition.
    stage_is_pt = TRAINING_STAGES[training_stage] == "pt"
    return gr.update(value=stage_is_pt)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
|
||||
import gradio as gr
|
||||
|
||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def create_chat_box(
|
||||
engine: "Engine", visible: Optional[bool] = False
|
||||
engine: "Engine", visible: bool = False
|
||||
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Box(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
|
||||
import gradio as gr
|
||||
|
||||
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
def create_top() -> Dict[str, "Component"]:
|
||||
def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
|
||||
with gr.Row():
|
||||
@@ -44,7 +44,7 @@ def create_top() -> Dict[str, "Component"]:
|
||||
|
||||
refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
|
||||
|
||||
return dict(
|
||||
return lang, dict(
|
||||
lang=lang,
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
|
||||
@@ -4,7 +4,7 @@ import gradio as gr
|
||||
from transformers.trainer_utils import SchedulerType
|
||||
|
||||
from ...extras.constants import TRAINING_STAGES
|
||||
from ..common import DEFAULT_DATA_DIR, list_adapters, list_dataset
|
||||
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
|
||||
from ..components.data import create_preview_box
|
||||
from ..utils import gen_plot
|
||||
|
||||
@@ -78,7 +78,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Row():
|
||||
resize_vocab = gr.Checkbox()
|
||||
sft_packing = gr.Checkbox()
|
||||
packing = gr.Checkbox()
|
||||
upcast_layernorm = gr.Checkbox()
|
||||
use_llama_pro = gr.Checkbox()
|
||||
shift_attn = gr.Checkbox()
|
||||
@@ -91,7 +91,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
neftune_alpha,
|
||||
optim,
|
||||
resize_vocab,
|
||||
sft_packing,
|
||||
packing,
|
||||
upcast_layernorm,
|
||||
use_llama_pro,
|
||||
shift_attn,
|
||||
@@ -106,7 +106,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
neftune_alpha=neftune_alpha,
|
||||
optim=optim,
|
||||
resize_vocab=resize_vocab,
|
||||
sft_packing=sft_packing,
|
||||
packing=packing,
|
||||
upcast_layernorm=upcast_layernorm,
|
||||
use_llama_pro=use_llama_pro,
|
||||
shift_attn=shift_attn,
|
||||
@@ -166,7 +166,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
|
||||
[reward_model],
|
||||
queue=False,
|
||||
)
|
||||
).then(autoset_packing, [training_stage], [packing], queue=False)
|
||||
|
||||
input_elems.update({dpo_beta, dpo_ftx, reward_model})
|
||||
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
from typing import Any, Dict, Generator
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
@@ -12,7 +12,7 @@ from .utils import get_time
|
||||
|
||||
|
||||
class Engine:
|
||||
def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
|
||||
def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
|
||||
self.demo_mode = demo_mode
|
||||
self.pure_chat = pure_chat
|
||||
self.manager = Manager()
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
import gradio as gr
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
@@ -19,7 +17,7 @@ from .engine import Engine
|
||||
require_version("gradio>=3.38.0,<4.0.0", 'To fix: pip install "gradio>=3.38.0,<4.0.0"')
|
||||
|
||||
|
||||
def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
|
||||
def create_ui(demo_mode: bool = False) -> gr.Blocks:
|
||||
engine = Engine(demo_mode=demo_mode, pure_chat=False)
|
||||
|
||||
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
|
||||
@@ -31,8 +29,7 @@ def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
|
||||
)
|
||||
gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
|
||||
|
||||
engine.manager.all_elems["top"] = create_top()
|
||||
lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
|
||||
lang, engine.manager.all_elems["top"] = create_top()
|
||||
|
||||
with gr.Tab("Train"):
|
||||
engine.manager.all_elems["train"] = create_train_tab(engine)
|
||||
|
||||
@@ -480,18 +480,18 @@ LOCALES = {
|
||||
"info": "更改分词器词表和嵌入层的大小。",
|
||||
},
|
||||
},
|
||||
"sft_packing": {
|
||||
"packing": {
|
||||
"en": {
|
||||
"label": "Pack sequences",
|
||||
"info": "Pack sequences into samples of fixed length in supervised fine-tuning.",
|
||||
"info": "Pack sequences into samples of fixed length.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Упаковка последовательностей",
|
||||
"info": "Упаковка последовательностей в образцы фиксированной длины при контролируемой тонкой настройке.",
|
||||
"info": "Упаковка последовательностей в образцы фиксированной длины.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "序列打包",
|
||||
"info": "在指令监督微调时将序列打包为等长样本。",
|
||||
"info": "将序列打包为等长样本。",
|
||||
},
|
||||
},
|
||||
"upcast_layernorm": {
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import os
|
||||
import time
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
|
||||
|
||||
import gradio as gr
|
||||
import transformers
|
||||
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Runner:
|
||||
def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
|
||||
def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
|
||||
self.manager = manager
|
||||
self.demo_mode = demo_mode
|
||||
""" Resume """
|
||||
@@ -136,7 +136,7 @@ class Runner:
|
||||
neftune_noise_alpha=get("train.neftune_alpha") or None,
|
||||
optim=get("train.optim"),
|
||||
resize_vocab=get("train.resize_vocab"),
|
||||
sft_packing=get("train.sft_packing"),
|
||||
packing=get("train.packing"),
|
||||
upcast_layernorm=get("train.upcast_layernorm"),
|
||||
use_llama_pro=get("train.use_llama_pro"),
|
||||
shift_attn=get("train.shift_attn"),
|
||||
|
||||
Reference in New Issue
Block a user