Support loading LoRA from the Hub
Former-commit-id: 0b34c962bc3368dca62b18ad6c27a0293c3affa5
@@ -29,7 +29,7 @@ from peft import (
     get_peft_model
 )
 
-from peft.utils import CONFIG_NAME
+from peft.utils import CONFIG_NAME, WEIGHTS_NAME
 
 from trl import AutoModelForCausalLMWithValueHead
 
@@ -103,8 +103,10 @@ def _init_adapter(
     lastest_checkpoint = None
 
     if model_args.checkpoint_dir is not None:
-        assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
-            "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
+        if os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)) and \
+                not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
+            raise ValueError("The given checkpoint may be not a LoRA checkpoint, \
+                please specify `--finetuning_type full/freeze` instead.")
 
         if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
            checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
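
This hunk is the core of the change: the old assert required the adapter config to exist locally, so a Hub repo id could never be passed as a checkpoint. The relaxed check only rejects a directory that contains weights but no adapter config; a path that does not exist on disk now falls through to peft, which can resolve it from the Hub. A minimal sketch of the resulting behavior, assuming peft's usual adapter file names (CONFIG_NAME = "adapter_config.json", WEIGHTS_NAME = "adapter_model.bin") and a hypothetical repo id:

    import os
    from peft import PeftModel
    from peft.utils import CONFIG_NAME, WEIGHTS_NAME

    def is_non_lora_checkpoint(checkpoint_dir: str) -> bool:
        # Mirrors the new check: weights on disk but no adapter config
        # means a full/freeze checkpoint, which LoRA loading must reject.
        return (os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_NAME))
                and not os.path.exists(os.path.join(checkpoint_dir, CONFIG_NAME)))

    # A checkpoint that exists nowhere on disk is no longer rejected up
    # front; PeftModel.from_pretrained() can fetch it from the Hub.
    # "someuser/llama-lora" is a hypothetical repo id for illustration.
    # model = PeftModel.from_pretrained(model, "someuser/llama-lora")
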
@@ -170,8 +172,7 @@ def load_pretrained(
         **config_kwargs
     )
     tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
-    if tokenizer.pad_token_id == 64000:
-        tokenizer.pad_token_id = 0 # for baichuan model (need fix)
+    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id == 64000 else tokenizer.pad_token_id # for baichuan model (older version)
 
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     is_mergeable = True
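
The Baichuan special case is folded into a ternary matching the style of the preceding <unk> fallback. The combined effect of the two kept lines, under the assumption stated in the comment that older Baichuan tokenizers report pad_token_id == 64000:

    # Combined effect of the two assignments above (sketch):
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = 0  # fall back to the <unk> token
    elif tokenizer.pad_token_id == 64000:
        tokenizer.pad_token_id = 0  # older Baichuan tokenizers misreport pad
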
@@ -212,7 +213,7 @@ def load_pretrained(
         low_cpu_mem_usage=True,
         **config_kwargs
     )
-    model = prepare_model_for_training(model) if is_trainable else model
+    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
     model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
 
     if stage == "rm" or stage == "ppo": # add value head
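
prepare_model_for_training now also receives the finetuning type, presumably so the helper can skip adapter-specific preparation for full/freeze runs. Its body is not part of this diff; the following is a purely hypothetical sketch of such a signature, not the repository's actual implementation:

    def prepare_model_for_training(model, finetuning_type: str = "lora"):
        # Hypothetical sketch: common preparation for every trainable run,
        # plus freezing base weights only when an adapter will be attached.
        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # incompatible with checkpointing
        if finetuning_type != "full":
            for param in model.parameters():
                param.requires_grad_(False)  # adapters add trainables later
        return model
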