mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-03 08:53:38 +00:00
refactor constants
Former-commit-id: a4d4c3fd35276f20e3b354e9d13ea971029c8775
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
from llmtuner.extras.logging import get_logger
|
||||
@@ -56,7 +56,7 @@ def prepare_model_for_training(
|
||||
finetuning_args: "FinetuningArguments",
|
||||
output_layer_name: Optional[str] = "lm_head",
|
||||
use_gradient_checkpointing: Optional[bool] = True,
|
||||
layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
|
||||
layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Includes:
|
||||
|
||||
Reference in New Issue
Block a user