allow non-packing pretraining
Former-commit-id: 3fee5cc5a3db9ce874ad90f2500ec092d904bd4e
@@ -2,7 +2,7 @@ import logging
 import os
 import time
 from threading import Thread
-from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
 
 import gradio as gr
 import transformers
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
 
 
 class Runner:
-    def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
+    def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
         self.manager = manager
         self.demo_mode = demo_mode
         """ Resume """
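
A side note on the `demo_mode` change above: `Optional[bool]` advertises that None is an accepted value, but the default is False and the flag is only ever read as a plain boolean, so tightening the annotation to `bool` (and dropping the now-unused `Optional` import in the first hunk) makes the signature match actual usage. A minimal illustration of the difference, using a hypothetical function rather than the repo's code:

from typing import Optional

def launch(demo: Optional[bool] = False) -> None:
    # type checkers allow launch(demo=None) here, forcing None checks downstream
    ...

def launch_strict(demo: bool = False) -> None:
    # launch_strict(demo=None) is now a type error, matching how the flag is used
    ...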
@@ -136,7 +136,7 @@ class Runner:
             neftune_noise_alpha=get("train.neftune_alpha") or None,
             optim=get("train.optim"),
             resize_vocab=get("train.resize_vocab"),
-            sft_packing=get("train.sft_packing"),
+            packing=get("train.packing"),
             upcast_layernorm=get("train.upcast_layernorm"),
             use_llama_pro=get("train.use_llama_pro"),
             shift_attn=get("train.shift_attn"),
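
The `sft_packing` to `packing` rename is the substance of the commit title: one flag now governs sequence packing across training stages, so pretraining can run with packing disabled rather than always packing. A rough sketch of what such a flag typically toggles during preprocessing; the function below is an assumption for illustration, not the repo's actual implementation:

def group_texts(token_lists: list[list[int]], packing: bool, cutoff_len: int) -> list[list[int]]:
    # packing off: keep one example per sequence, truncated to cutoff_len
    if not packing:
        return [ids[:cutoff_len] for ids in token_lists]
    # packing on: concatenate all tokens and re-split into fixed-size blocks
    flat = [tok for ids in token_lists for tok in ids]
    return [flat[i : i + cutoff_len] for i in range(0, len(flat) - cutoff_len + 1, cutoff_len)]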