better llamaboard

* easily resume from checkpoint
* support full and freeze checkpoints
* faster ui


Former-commit-id: 84cfb2452cc86b037ccddee6e833f8eb7c129fa4
This commit is contained in:
hiyouga
2024-05-29 23:55:38 +08:00
parent f90c4ca672
commit 87aa332583
14 changed files with 303 additions and 193 deletions

View File

@@ -2,6 +2,19 @@ from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
CHECKPOINT_NAMES = {
SAFE_ADAPTER_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
}
CHOICES = ["A", "B", "C", "D"]
@@ -26,9 +39,9 @@ LAYERNORM_NAMES = {"norm", "ln"}
METHODS = ["full", "freeze", "lora"]
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
PEFT_METHODS = ["lora"]
PEFT_METHODS = {"lora"}
RUNNING_LOG = "running_log.txt"
@@ -49,9 +62,9 @@ TRAINING_STAGES = {
"Pre-Training": "pt",
}
STAGES_USE_PAIR_DATA = ["rm", "dpo", "orpo"]
STAGES_USE_PAIR_DATA = {"rm", "dpo"}
SUPPORTED_CLASS_FOR_S2ATTN = ["llama"]
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
V_HEAD_WEIGHTS_NAME = "value_head.bin"