support llama pro #2338 , add rslora

Former-commit-id: 40d659b7f30dd5a004703c176ec1f22dc864e505
This commit is contained in:
hiyouga
2024-02-15 02:27:36 +08:00
parent b403f8d8a8
commit 596b6828cb
24 changed files with 438 additions and 203 deletions

View File

@@ -4,7 +4,7 @@ import torch
from ..extras.logging import get_logger
from ..hparams import FinetuningArguments, ModelArguments
from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
from ..model import load_model_and_tokenizer, load_valuehead_params
if TYPE_CHECKING:
@@ -25,14 +25,18 @@ def create_modelcard_and_push(
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
) -> None:
if training_args.do_train:
if training_args.push_to_hub:
trainer.push_to_hub(**get_modelcard_args(model_args, data_args, finetuning_args))
return
try:
trainer.create_model_card(**get_modelcard_args(model_args, data_args, finetuning_args))
except Exception as err:
logger.warning("Failed to create model card: {}".format(str(err)))
kwargs = {
"tasks": "text-generation",
"finetuned_from": model_args.model_name_or_path,
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
"tags": ["llama-factory", finetuning_args.finetuning_type],
}
if not training_args.do_train:
pass
elif training_args.push_to_hub:
trainer.push_to_hub(**kwargs)
else:
trainer.create_model_card(license="other", **kwargs) # prevent from connecting to hub
def create_ref_model(