refactor constants

Former-commit-id: a4d4c3fd35276f20e3b354e9d13ea971029c8775
This commit is contained in:
hiyouga
2023-11-10 14:16:10 +08:00
parent 68dd1ef121
commit 178b85ff9a
6 changed files with 243 additions and 83 deletions

View File

@@ -1,5 +1,5 @@
import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from llmtuner.extras.constants import LAYERNORM_NAMES
from llmtuner.extras.logging import get_logger
@@ -56,7 +56,7 @@ def prepare_model_for_training(
finetuning_args: "FinetuningArguments",
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
r"""
Includes: