move configure_packing to llamafactory.model.patcher and fix constants

Former-commit-id: 9c5e972c9c81957f2e9e30bf284ef1c076de9fd0
ancv
2024-06-21 00:45:06 +07:00
parent dd7a1dbfae
commit 6c185a2c57
6 changed files with 16 additions and 12 deletions

llamafactory/model/patcher.py

@@ -19,13 +19,13 @@ from .model_utils.quantization import configure_quantization
 from .model_utils.rope import configure_rope
 from .model_utils.valuehead import prepare_valuehead_model
 from .model_utils.visual import autocast_projector_dtype, configure_visual_model
+from .model_utils.packing import configure_packing

 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import ModelArguments
+    from ..hparams import ModelArguments, DataArguments, FinetuningArguments


 logger = get_logger(__name__)
@@ -40,6 +40,8 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    data_args: "DataArguments",
+    finetune_args: "FinetuningArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
@@ -81,6 +83,9 @@ def patch_config(
     if init_kwargs["device_map"] == "auto":
         init_kwargs["offload_folder"] = model_args.offload_folder
+
+    if finetune_args.stage == "sft" and data_args.efficient_packing:
+        configure_packing(config, model_args)


 def patch_model(
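
The new guard only enables packing for SFT runs that set data_args.efficient_packing. The body of configure_packing is not part of this diff, so the sketch below is only an illustration of the property packing has to preserve: examples packed into one sequence must not attend to each other, which an additive block-diagonal causal mask enforces. The helper name build_packed_attention_mask is hypothetical and is not LLaMA-Factory code.

import torch


def build_packed_attention_mask(seq_lens, dtype=torch.float32):
    """Block-diagonal causal mask: 0.0 where attention is allowed, -inf elsewhere."""
    total = sum(seq_lens)
    mask = torch.full((total, total), float("-inf"), dtype=dtype)
    offset = 0
    for length in seq_lens:
        # Within one packed example, allow causal attention only (no peeking ahead,
        # no attending across the example boundary).
        mask[offset:offset + length, offset:offset + length] = torch.triu(
            torch.full((length, length), float("-inf"), dtype=dtype), diagonal=1
        )
        offset += length
    return mask


# Two examples of lengths 3 and 2 packed into one length-5 sequence:
# tokens of the second example cannot attend to tokens of the first, and vice versa.
print(build_packed_attention_mask([3, 2]))

In practice, FlashAttention-style kernels achieve the same effect from per-example sequence lengths instead of a dense mask, but the dense form makes the boundary rule easy to see.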