Merge branch 'hiyouga:main' into main
Former-commit-id: c25734d874a36222e0a540a2c994bbda73008b27
This commit is contained in:
@@ -701,17 +701,8 @@ _register_template(
|
||||
_register_template(
|
||||
name="llama2",
|
||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||
format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]),
|
||||
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
||||
default_system=(
|
||||
"You are a helpful, respectful and honest assistant. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -35,6 +35,8 @@ IGNORE_INDEX = -100
|
||||
|
||||
LAYERNORM_NAMES = {"norm", "ln"}
|
||||
|
||||
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
||||
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
||||
@@ -47,10 +49,10 @@ SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
||||
|
||||
SUPPORTED_MODELS = OrderedDict()
|
||||
|
||||
TRAINER_CONFIG = "trainer_config.yaml"
|
||||
|
||||
TRAINER_LOG = "trainer_log.jsonl"
|
||||
|
||||
TRAINING_ARGS = "training_args.yaml"
|
||||
|
||||
TRAINING_STAGES = {
|
||||
"Supervised Fine-Tuning": "sft",
|
||||
"Reward Modeling": "rm",
|
||||
|
||||
@@ -6,13 +6,10 @@ import peft
|
||||
import torch
|
||||
import transformers
|
||||
import trl
|
||||
from transformers.integrations import is_deepspeed_available
|
||||
from transformers.utils import is_bitsandbytes_available, is_torch_cuda_available, is_torch_npu_available
|
||||
|
||||
from .packages import is_vllm_available
|
||||
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
||||
|
||||
|
||||
VERSION = "0.7.2.dev0"
|
||||
VERSION = "0.8.1.dev0"
|
||||
|
||||
|
||||
def print_env() -> None:
|
||||
@@ -37,19 +34,25 @@ def print_env() -> None:
|
||||
info["NPU type"] = torch.npu.get_device_name()
|
||||
info["CANN version"] = torch.version.cann
|
||||
|
||||
if is_deepspeed_available():
|
||||
try:
|
||||
import deepspeed # type: ignore
|
||||
|
||||
info["DeepSpeed version"] = deepspeed.__version__
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
try:
|
||||
import bitsandbytes
|
||||
|
||||
info["Bitsandbytes version"] = bitsandbytes.__version__
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if is_vllm_available():
|
||||
try:
|
||||
import vllm
|
||||
|
||||
info["vLLM version"] = vllm.__version__
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
|
||||
|
||||
@@ -15,7 +15,12 @@ class ModelArguments:
|
||||
)
|
||||
adapter_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."},
|
||||
metadata={
|
||||
"help": (
|
||||
"Path to the adapter weight or identifier from huggingface.co/models. "
|
||||
"Use commas to separate multiple adapters."
|
||||
)
|
||||
},
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -35,7 +40,7 @@ class ModelArguments:
|
||||
)
|
||||
new_special_tokens: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Special tokens to be added into the tokenizer."},
|
||||
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
|
||||
@@ -21,6 +21,218 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _setup_full_tuning(
    model: "PreTrainedModel",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    cast_trainable_params_to_fp32: bool,
) -> None:
    r"""
    Prepares the model for full-parameter fine-tuning.

    Parameters whose name contains a frozen keyword get `requires_grad_(False)`;
    all other parameters are optionally upcast to float32 in place.
    """
    logger.info("Fine-tuning method: Full")

    # Keywords are only collected for models with visual inputs.
    frozen_keywords = set()
    if model_args.visual_inputs:
        if finetuning_args.freeze_vision_tower:
            frozen_keywords.add("vision_tower")

        if finetuning_args.train_mm_proj_only:
            frozen_keywords.add("language_model")

    for param_name, weight in model.named_parameters():
        is_frozen = any(keyword in param_name for keyword in frozen_keywords)
        if is_frozen:
            weight.requires_grad_(False)
        elif cast_trainable_params_to_fp32:
            weight.data = weight.data.to(torch.float32)
|
||||
|
||||
|
||||
def _setup_freeze_tuning(
    model: "PreTrainedModel",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    cast_trainable_params_to_fp32: bool,
) -> None:
    r"""
    Prepares the model for freeze (partial-parameter) fine-tuning.

    Selects a set of layer indices from `finetuning_args.freeze_trainable_layers`,
    combines them with `finetuning_args.freeze_trainable_modules` and
    `finetuning_args.freeze_extra_modules` into name patterns, and disables gradients
    for every parameter whose name matches none of them. Matching parameters are
    optionally upcast to float32 in place.

    Raises:
        ValueError: if the layer count cannot be read from the config, if
            `freeze_trainable_layers` is zero or does not divide the layer count
            under `use_llama_pro`, or if a requested module name is unknown.
    """
    logger.info("Fine-tuning method: Freeze")
    if model_args.visual_inputs:
        config = model.config.text_config
    else:
        config = model.config

    # Different architectures expose the layer count under different attribute names.
    num_layers = (
        getattr(config, "num_hidden_layers", None)
        or getattr(config, "num_layers", None)
        or getattr(config, "n_layer", None)
    )
    if not num_layers:
        raise ValueError("Current model does not support freeze tuning.")

    if finetuning_args.use_llama_pro:
        # Guard the modulo/floor-division below against a zero divisor.
        if finetuning_args.freeze_trainable_layers == 0:
            raise ValueError("`freeze_trainable_layers` should not be zero when `use_llama_pro` is enabled.")

        if num_layers % finetuning_args.freeze_trainable_layers != 0:
            raise ValueError(
                "`num_layers` {} should be divisible by `freeze_trainable_layers` {}.".format(
                    num_layers, finetuning_args.freeze_trainable_layers
                )
            )

        # Pick the last layer of every `stride`-sized group.
        stride = num_layers // finetuning_args.freeze_trainable_layers
        trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
    elif finetuning_args.freeze_trainable_layers > 0:  # fine-tuning the last n layers if freeze_trainable_layers > 0
        trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers)
    else:  # fine-tuning the first n layers if freeze_trainable_layers < 0
        trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers))

    # Collect the module names found inside indexed layers and outside of them.
    hidden_modules = set()
    non_hidden_modules = set()
    for name, _ in model.named_parameters():
        if ".0." in name:
            hidden_modules.add(name.split(".0.")[-1].split(".")[0])
        elif ".1." in name:  # MoD starts from layer 1
            hidden_modules.add(name.split(".1.")[-1].split(".")[0])

        if re.search(r"\.\d+\.", name) is None:
            non_hidden_modules.add(name.split(".")[-2])

    # Build substring patterns of the form ".<layer idx>.<module>" ("all" matches the whole layer).
    trainable_layers = []
    for module_name in finetuning_args.freeze_trainable_modules:
        if module_name != "all" and module_name not in hidden_modules:
            raise ValueError(
                "Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules))
            )

        for idx in trainable_layer_ids:
            trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))

    if finetuning_args.freeze_extra_modules:
        for module_name in finetuning_args.freeze_extra_modules:
            if module_name not in non_hidden_modules:
                raise ValueError(
                    "Module {} is not found, please choose from {}".format(module_name, ", ".join(non_hidden_modules))
                )

            trainable_layers.append(module_name)

    forbidden_modules = set()
    if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
        forbidden_modules.add("vision_tower")

    for name, param in model.named_parameters():
        if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
            forbidden_module in name for forbidden_module in forbidden_modules
        ):
            if cast_trainable_params_to_fp32:
                param.data = param.data.to(torch.float32)
        else:
            param.requires_grad_(False)

    logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
|
||||
|
||||
|
||||
def _setup_lora_tuning(
    config: "PretrainedConfig",
    model: "PreTrainedModel",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    is_trainable: bool,
    cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
    r"""
    Sets up LoRA (or DoRA) tuning on the model and returns the resulting model.

    Adapters listed in `model_args.adapter_name_or_path` are merged into the model in order.
    The last one is instead loaded without merging when training without `create_new_adapter`,
    or when merging is disabled (quantized model, DeepSpeed ZeRO-3 or unsloth). When training
    and no adapter is resumed, a new adapter is created from `finetuning_args`.
    """
    method_name = "DoRA" if finetuning_args.use_dora else "LoRA"
    logger.info("Fine-tuning method: {}".format(method_name))
    adapter_to_resume = None

    adapter_paths = model_args.adapter_name_or_path
    if adapter_paths is not None:
        # Each restriction below requires a single adapter and forbids merging it.
        is_mergeable = True
        if getattr(model, "quantization_method", None):  # merge lora in quantized model is unstable
            assert len(adapter_paths) == 1, "Quantized model only accepts a single adapter."
            is_mergeable = False

        if is_deepspeed_zero3_enabled():
            assert len(adapter_paths) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
            is_mergeable = False

        if model_args.use_unsloth:
            assert len(adapter_paths) == 1, "Unsloth model only accepts a single adapter."
            is_mergeable = False

        resume_last = (not is_mergeable) or (is_trainable and not finetuning_args.create_new_adapter)
        if resume_last:
            adapter_to_merge, adapter_to_resume = adapter_paths[:-1], adapter_paths[-1]
        else:
            adapter_to_merge = adapter_paths

        for adapter_path in adapter_to_merge:
            lora_model = PeftModel.from_pretrained(model, adapter_path, offload_folder=model_args.offload_folder)
            model = lora_model.merge_and_unload()

        num_merged = len(adapter_to_merge)
        if num_merged > 0:
            logger.info("Merged {} adapter(s).".format(num_merged))

        if adapter_to_resume is not None:  # resume lora training
            if model_args.use_unsloth:
                model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
            else:
                model = PeftModel.from_pretrained(
                    model,
                    adapter_to_resume,
                    is_trainable=is_trainable,
                    offload_folder=model_args.offload_folder,
                )

    if is_trainable and adapter_to_resume is None:  # create new lora weights while training
        targets_all = len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all"
        if targets_all:
            target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
        else:
            target_modules = finetuning_args.lora_target

        if finetuning_args.use_llama_pro:
            target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)

        if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
            # Turn the module list into a regex that rejects names containing "vision_tower".
            target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))

        if finetuning_args.use_dora:
            quant_method = getattr(model, "quantization_method", None)
            if quant_method is not None and quant_method != QuantizationMethod.BITS_AND_BYTES:
                raise ValueError("DoRA is not compatible with PTQ-quantized models.")

        if model_args.resize_vocab and finetuning_args.additional_target is None:
            embedding_layers = [model.get_input_embeddings(), model.get_output_embeddings()]
            module_names = {
                name.split(".")[-1] for name, module in model.named_modules() if module in embedding_layers
            }
            finetuning_args.additional_target = module_names
            logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))

        peft_kwargs = dict(
            r=finetuning_args.lora_rank,
            target_modules=target_modules,
            lora_alpha=finetuning_args.lora_alpha,
            lora_dropout=finetuning_args.lora_dropout,
            use_rslora=finetuning_args.use_rslora,
            modules_to_save=finetuning_args.additional_target,
        )

        if model_args.use_unsloth:
            model = get_unsloth_peft_model(model, model_args, peft_kwargs)
        else:
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                use_dora=finetuning_args.use_dora,
                **peft_kwargs,
            )
            model = get_peft_model(model, lora_config)

    if is_trainable and cast_trainable_params_to_fp32:
        for weight in model.parameters():
            if weight.requires_grad:
                weight.data = weight.data.to(torch.float32)

    if model_args.adapter_name_or_path is not None:
        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

    return model
|
||||
|
||||
|
||||
def init_adapter(
|
||||
config: "PretrainedConfig",
|
||||
model: "PreTrainedModel",
|
||||
@@ -35,7 +247,6 @@ def init_adapter(
|
||||
|
||||
Note that the trainable parameters must be cast to float32.
|
||||
"""
|
||||
|
||||
if (not is_trainable) and model_args.adapter_name_or_path is None:
|
||||
logger.info("Adapter is not found at evaluation, load the base model.")
|
||||
return model
|
||||
@@ -50,200 +261,15 @@ def init_adapter(
|
||||
logger.info("Upcasting trainable params to float32.")
|
||||
cast_trainable_params_to_fp32 = True
|
||||
|
||||
if finetuning_args.finetuning_type == "full" and is_trainable:
|
||||
logger.info("Fine-tuning method: Full")
|
||||
if is_trainable and finetuning_args.finetuning_type == "full":
|
||||
_setup_full_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
|
||||
|
||||
forbidden_modules = set()
|
||||
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
|
||||
forbidden_modules.add("vision_tower")
|
||||
|
||||
if model_args.visual_inputs and finetuning_args.train_mm_proj_only:
|
||||
forbidden_modules.add("language_model")
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if not any(forbidden_module in name for forbidden_module in forbidden_modules):
|
||||
if cast_trainable_params_to_fp32:
|
||||
param.data = param.data.to(torch.float32)
|
||||
else:
|
||||
param.requires_grad_(False)
|
||||
|
||||
if finetuning_args.finetuning_type == "freeze" and is_trainable:
|
||||
logger.info("Fine-tuning method: Freeze")
|
||||
|
||||
if model_args.visual_inputs:
|
||||
config = model.config.text_config
|
||||
else:
|
||||
config = model.config
|
||||
|
||||
num_layers = (
|
||||
getattr(config, "num_hidden_layers", None)
|
||||
or getattr(config, "num_layers", None)
|
||||
or getattr(config, "n_layer", None)
|
||||
)
|
||||
if not num_layers:
|
||||
raise ValueError("Current model does not support freeze tuning.")
|
||||
|
||||
if finetuning_args.use_llama_pro:
|
||||
if num_layers % finetuning_args.freeze_trainable_layers != 0:
|
||||
raise ValueError(
|
||||
"`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
|
||||
num_layers, finetuning_args.freeze_trainable_layers
|
||||
)
|
||||
)
|
||||
|
||||
stride = num_layers // finetuning_args.freeze_trainable_layers
|
||||
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
|
||||
elif finetuning_args.freeze_trainable_layers > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||
trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers)
|
||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||
trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers))
|
||||
|
||||
hidden_modules = set()
|
||||
non_hidden_modules = set()
|
||||
for name, _ in model.named_parameters():
|
||||
if ".0." in name:
|
||||
hidden_modules.add(name.split(".0.")[-1].split(".")[0])
|
||||
elif ".1." in name: # MoD starts from layer 1
|
||||
hidden_modules.add(name.split(".1.")[-1].split(".")[0])
|
||||
|
||||
if re.search(r"\.\d+\.", name) is None:
|
||||
non_hidden_modules.add(name.split(".")[-2])
|
||||
|
||||
trainable_layers = []
|
||||
for module_name in finetuning_args.freeze_trainable_modules:
|
||||
if module_name != "all" and module_name not in hidden_modules:
|
||||
raise ValueError(
|
||||
"Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules))
|
||||
)
|
||||
|
||||
for idx in trainable_layer_ids:
|
||||
trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))
|
||||
|
||||
if finetuning_args.freeze_extra_modules:
|
||||
for module_name in finetuning_args.freeze_extra_modules:
|
||||
if module_name not in non_hidden_modules:
|
||||
raise ValueError(
|
||||
"Module {} is not found, please choose from {}".format(
|
||||
module_name, ", ".join(non_hidden_modules)
|
||||
)
|
||||
)
|
||||
|
||||
trainable_layers.append(module_name)
|
||||
|
||||
forbidden_modules = set()
|
||||
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
|
||||
forbidden_modules.add("vision_tower")
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
|
||||
forbidden_module in name for forbidden_module in forbidden_modules
|
||||
):
|
||||
if cast_trainable_params_to_fp32:
|
||||
param.data = param.data.to(torch.float32)
|
||||
else:
|
||||
param.requires_grad_(False)
|
||||
|
||||
logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
|
||||
if is_trainable and finetuning_args.finetuning_type == "freeze":
|
||||
_setup_freeze_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
|
||||
adapter_to_resume = None
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
is_mergeable = True
|
||||
if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
|
||||
assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
|
||||
is_mergeable = False
|
||||
|
||||
if is_deepspeed_zero3_enabled():
|
||||
assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
|
||||
is_mergeable = False
|
||||
|
||||
if model_args.use_unsloth:
|
||||
assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
|
||||
is_mergeable = False
|
||||
|
||||
if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
|
||||
adapter_to_merge = model_args.adapter_name_or_path[:-1]
|
||||
adapter_to_resume = model_args.adapter_name_or_path[-1]
|
||||
else:
|
||||
adapter_to_merge = model_args.adapter_name_or_path
|
||||
|
||||
for adapter in adapter_to_merge:
|
||||
model: "LoraModel" = PeftModel.from_pretrained(
|
||||
model, adapter, offload_folder=model_args.offload_folder
|
||||
)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if len(adapter_to_merge) > 0:
|
||||
logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
|
||||
|
||||
if adapter_to_resume is not None: # resume lora training
|
||||
if model_args.use_unsloth:
|
||||
model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
|
||||
else:
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
adapter_to_resume,
|
||||
is_trainable=is_trainable,
|
||||
offload_folder=model_args.offload_folder,
|
||||
)
|
||||
|
||||
if is_trainable and adapter_to_resume is None: # create new lora weights while training
|
||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||
target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
|
||||
else:
|
||||
target_modules = finetuning_args.lora_target
|
||||
|
||||
if finetuning_args.use_llama_pro:
|
||||
target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
|
||||
|
||||
if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
|
||||
target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
|
||||
|
||||
if (
|
||||
finetuning_args.use_dora
|
||||
and getattr(model, "quantization_method", None) is not None
|
||||
and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
|
||||
):
|
||||
raise ValueError("DoRA is not compatible with PTQ-quantized models.")
|
||||
|
||||
if model_args.resize_vocab and finetuning_args.additional_target is None:
|
||||
input_embeddings = model.get_input_embeddings()
|
||||
output_embeddings = model.get_output_embeddings()
|
||||
module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if module in [input_embeddings, output_embeddings]:
|
||||
module_names.add(name.split(".")[-1])
|
||||
|
||||
finetuning_args.additional_target = module_names
|
||||
logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
|
||||
|
||||
peft_kwargs = {
|
||||
"r": finetuning_args.lora_rank,
|
||||
"target_modules": target_modules,
|
||||
"lora_alpha": finetuning_args.lora_alpha,
|
||||
"lora_dropout": finetuning_args.lora_dropout,
|
||||
"use_rslora": finetuning_args.use_rslora,
|
||||
"modules_to_save": finetuning_args.additional_target,
|
||||
}
|
||||
|
||||
if model_args.use_unsloth:
|
||||
model = get_unsloth_peft_model(model, model_args, peft_kwargs)
|
||||
else:
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
use_dora=finetuning_args.use_dora,
|
||||
**peft_kwargs,
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
if cast_trainable_params_to_fp32:
|
||||
for param in filter(lambda p: p.requires_grad, model.parameters()):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
|
||||
model = _setup_lora_tuning(
|
||||
config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
@@ -50,13 +50,6 @@ def get_config_path() -> os.PathLike:
|
||||
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
|
||||
|
||||
|
||||
def get_arg_save_path(config_path: str) -> os.PathLike:
    r"""
    Joins `config_path` onto the default config directory and returns the result.
    """
    save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path)
    return save_path
|
||||
|
||||
|
||||
def load_config() -> Dict[str, Any]:
|
||||
r"""
|
||||
Loads user config if exists.
|
||||
@@ -77,24 +70,28 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
|
||||
user_config["lang"] = lang or user_config["lang"]
|
||||
if model_name:
|
||||
user_config["last_model"] = model_name
|
||||
|
||||
if model_name and model_path:
|
||||
user_config["path_dict"][model_name] = model_path
|
||||
|
||||
with open(get_config_path(), "w", encoding="utf-8") as f:
|
||||
safe_dump(user_config, f)
|
||||
|
||||
|
||||
def get_model_path(model_name: str) -> Optional[str]:
|
||||
def get_model_path(model_name: str) -> str:
|
||||
r"""
|
||||
Gets the model path according to the model name.
|
||||
"""
|
||||
user_config = load_config()
|
||||
path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
|
||||
model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, None)
|
||||
path_dict: Dict["DownloadSource", str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
|
||||
model_path = user_config["path_dict"].get(model_name, "") or path_dict.get(DownloadSource.DEFAULT, "")
|
||||
if (
|
||||
use_modelscope()
|
||||
and path_dict.get(DownloadSource.MODELSCOPE)
|
||||
and model_path == path_dict.get(DownloadSource.DEFAULT)
|
||||
): # replace path
|
||||
model_path = path_dict.get(DownloadSource.MODELSCOPE)
|
||||
|
||||
return model_path
|
||||
|
||||
|
||||
|
||||
@@ -36,7 +36,8 @@ def create_top() -> Dict[str, "Component"]:
|
||||
visual_inputs = gr.Checkbox(scale=1)
|
||||
|
||||
model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False)
|
||||
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||
model_name.input(save_config, inputs=[lang, model_name], queue=False)
|
||||
model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||
finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False)
|
||||
checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from ...extras.constants import TRAINING_STAGES
|
||||
from ...extras.misc import get_device_count
|
||||
from ...extras.packages import is_gradio_available
|
||||
from ..common import DEFAULT_DATA_DIR, list_checkpoints, list_datasets
|
||||
from ..utils import change_stage, check_output_dir, list_config_paths, list_output_dirs
|
||||
from ..utils import change_stage, list_config_paths, list_output_dirs
|
||||
from .data import create_preview_box
|
||||
|
||||
|
||||
@@ -319,7 +319,13 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
finetuning_type.change(list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], queue=False)
|
||||
output_dir.change(
|
||||
list_output_dirs, [model_name, finetuning_type, current_time], [output_dir], concurrency_limit=None
|
||||
).then(check_output_dir, inputs=[lang, model_name, finetuning_type, output_dir], concurrency_limit=None)
|
||||
)
|
||||
output_dir.input(
|
||||
engine.runner.check_output_dir,
|
||||
[lang, model_name, finetuning_type, output_dir],
|
||||
list(input_elems) + [output_box],
|
||||
concurrency_limit=None,
|
||||
)
|
||||
config_path.change(list_config_paths, [current_time], [config_path], queue=False)
|
||||
|
||||
return elem_dict
|
||||
|
||||
@@ -5,11 +5,11 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
||||
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
|
||||
from ..extras.constants import PEFT_METHODS, TRAINING_STAGES
|
||||
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
|
||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc
|
||||
from ..extras.packages import is_gradio_available
|
||||
from .common import DEFAULT_CACHE_DIR, get_save_dir, load_config
|
||||
from .locales import ALERTS
|
||||
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
|
||||
from .locales import ALERTS, LOCALES
|
||||
from .utils import abort_leaf_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
|
||||
|
||||
|
||||
@@ -276,6 +276,10 @@ class Runner:
|
||||
else:
|
||||
self.do_train, self.running_data = do_train, data
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
|
||||
os.makedirs(args["output_dir"], exist_ok=True)
|
||||
save_args(os.path.join(args["output_dir"], LLAMABOARD_CONFIG), self._form_config_dict(data))
|
||||
|
||||
env = deepcopy(os.environ)
|
||||
env["LLAMABOARD_ENABLED"] = "1"
|
||||
if args.get("deepspeed", None) is not None:
|
||||
@@ -284,6 +288,16 @@ class Runner:
|
||||
self.trainer = Popen("llamafactory-cli train {}".format(save_cmd(args)), env=env, shell=True)
|
||||
yield from self.monitor()
|
||||
|
||||
def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
    r"""
    Maps each component in `data` to its element id and returns the id -> value dict,
    leaving out the ids listed in `excluded_ids`.
    """
    excluded_ids = ("top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count")
    id_value_pairs = ((self.manager.get_id_by_elem(elem), value) for elem, value in data.items())
    return {elem_id: value for elem_id, value in id_value_pairs if elem_id not in excluded_ids}
|
||||
|
||||
def preview_train(self, data):
    r"""
    Delegates to `self._preview` with `do_train=True` and yields everything it yields.
    """
    yield from self._preview(data, do_train=True)
|
||||
|
||||
@@ -349,28 +363,24 @@ class Runner:
|
||||
}
|
||||
yield return_dict
|
||||
|
||||
def save_args(self, data: dict):
|
||||
def save_args(self, data):
|
||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||
error = self._initialize(data, do_train=True, from_preview=True)
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
return {output_box: error}
|
||||
|
||||
config_dict: Dict[str, Any] = {}
|
||||
lang = data[self.manager.get_elem_by_id("top.lang")]
|
||||
config_path = data[self.manager.get_elem_by_id("train.config_path")]
|
||||
skip_ids = ["top.lang", "top.model_path", "train.output_dir", "train.config_path", "train.device_count"]
|
||||
for elem, value in data.items():
|
||||
elem_id = self.manager.get_id_by_elem(elem)
|
||||
if elem_id not in skip_ids:
|
||||
config_dict[elem_id] = value
|
||||
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
|
||||
save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path)
|
||||
|
||||
save_path = save_args(config_path, config_dict)
|
||||
save_args(save_path, self._form_config_dict(data))
|
||||
return {output_box: ALERTS["info_config_saved"][lang] + save_path}
|
||||
|
||||
def load_args(self, lang: str, config_path: str):
|
||||
output_box = self.manager.get_elem_by_id("train.output_box")
|
||||
config_dict = load_args(config_path)
|
||||
config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path))
|
||||
if config_dict is None:
|
||||
gr.Warning(ALERTS["err_config_not_found"][lang])
|
||||
return {output_box: ALERTS["err_config_not_found"][lang]}
|
||||
@@ -380,3 +390,17 @@ class Runner:
|
||||
output_dict[self.manager.get_elem_by_id(elem_id)] = value
|
||||
|
||||
return output_dict
|
||||
|
||||
def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str):
    r"""
    Builds the component updates to apply when the output directory changes.

    The output box is reset to its localized default value. If the resolved save
    directory already exists, a warning is raised in the UI, the output box shows the
    warning text, and the values stored in that directory's llamaboard config file
    (if any) are added to the returned updates.

    Returns:
        A dict mapping components to their new values.
    """
    output_box = self.manager.get_elem_by_id("train.output_box")
    output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]}
    if not (model_name and output_dir):
        return output_dict

    save_dir = get_save_dir(model_name, finetuning_type, output_dir)
    if os.path.isdir(save_dir):
        gr.Warning(ALERTS["warn_output_dir_exists"][lang])
        output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang]

        # `load_args` returns None when the file is missing or unreadable (e.g. a directory
        # that has no llamaboard config); fall back to an empty dict instead of crashing.
        config_dict = load_args(os.path.join(save_dir, LLAMABOARD_CONFIG)) or {}  # load llamaboard config
        for elem_id, value in config_dict.items():
            output_dict[self.manager.get_elem_by_id(elem_id)] = value

    return output_dict
|
||||
|
||||
@@ -8,10 +8,10 @@ import psutil
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from yaml import safe_dump, safe_load
|
||||
|
||||
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_CONFIG, TRAINER_LOG, TRAINING_STAGES
|
||||
from ..extras.constants import PEFT_METHODS, RUNNING_LOG, TRAINER_LOG, TRAINING_ARGS, TRAINING_STAGES
|
||||
from ..extras.packages import is_gradio_available, is_matplotlib_available
|
||||
from ..extras.ploting import gen_loss_plot
|
||||
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_arg_save_path, get_save_dir
|
||||
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir
|
||||
from .locales import ALERTS
|
||||
|
||||
|
||||
@@ -93,10 +93,10 @@ def save_cmd(args: Dict[str, Any]) -> str:
|
||||
output_dir = args["output_dir"]
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
with open(os.path.join(output_dir, TRAINER_CONFIG), "w", encoding="utf-8") as f:
|
||||
with open(os.path.join(output_dir, TRAINING_ARGS), "w", encoding="utf-8") as f:
|
||||
safe_dump(clean_cmd(args), f)
|
||||
|
||||
return os.path.join(output_dir, TRAINER_CONFIG)
|
||||
return os.path.join(output_dir, TRAINING_ARGS)
|
||||
|
||||
|
||||
def get_eval_results(path: os.PathLike) -> str:
|
||||
@@ -157,22 +157,19 @@ def load_args(config_path: str) -> Optional[Dict[str, Any]]:
|
||||
Loads saved arguments.
|
||||
"""
|
||||
try:
|
||||
with open(get_arg_save_path(config_path), "r", encoding="utf-8") as f:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
return safe_load(f)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
|
||||
def save_args(config_path: str, config_dict: Dict[str, Any]):
|
||||
r"""
|
||||
Saves arguments.
|
||||
"""
|
||||
os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
|
||||
with open(get_arg_save_path(config_path), "w", encoding="utf-8") as f:
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
safe_dump(config_dict, f)
|
||||
|
||||
return str(get_arg_save_path(config_path))
|
||||
|
||||
|
||||
def list_config_paths(current_time: str) -> "gr.Dropdown":
|
||||
r"""
|
||||
@@ -181,13 +178,13 @@ def list_config_paths(current_time: str) -> "gr.Dropdown":
|
||||
config_files = ["{}.yaml".format(current_time)]
|
||||
if os.path.isdir(DEFAULT_CONFIG_DIR):
|
||||
for file_name in os.listdir(DEFAULT_CONFIG_DIR):
|
||||
if file_name.endswith(".yaml"):
|
||||
if file_name.endswith(".yaml") and file_name not in config_files:
|
||||
config_files.append(file_name)
|
||||
|
||||
return gr.Dropdown(choices=config_files)
|
||||
|
||||
|
||||
def list_output_dirs(model_name: str, finetuning_type: str, current_time: str) -> "gr.Dropdown":
|
||||
def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
|
||||
r"""
|
||||
Lists all the directories that can resume from.
|
||||
"""
|
||||
@@ -203,14 +200,6 @@ def list_output_dirs(model_name: str, finetuning_type: str, current_time: str) -
|
||||
return gr.Dropdown(choices=output_dirs)
|
||||
|
||||
|
||||
def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None:
    r"""
    Raises a UI warning when the save directory for the given run already exists.
    """
    # Nothing to check until both a model and an output directory are chosen.
    if not model_name or not output_dir:
        return

    if os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
        gr.Warning(ALERTS["warn_output_dir_exists"][lang])
|
||||
|
||||
|
||||
def create_ds_config() -> None:
|
||||
r"""
|
||||
Creates deepspeed config.
|
||||
|
||||
Reference in New Issue
Block a user