support save full model, replace BOS token

Former-commit-id: 32e56c290802ba971c08f471b94a33daec85671a
hiyouga
2023-06-27 21:40:11 +08:00
parent 33c2b063c6
commit 640f774d30
3 changed files with 48 additions and 33 deletions


@@ -103,8 +103,9 @@ def _init_adapter(
     lastest_checkpoint = None

     if model_args.checkpoint_dir is not None:
-        if os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)) and \
-                not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
+        if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)):
+            raise ValueError("Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0]))
+        if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
             raise ValueError("The given checkpoint may be not a LoRA checkpoint, \
                 please specify `--finetuning_type full/freeze` instead.")
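
A note on the new control flow: the combined check is split so that a missing weight file and a missing adapter config now fail with distinct messages. Below is a minimal standalone sketch, assuming PEFT's conventional file names for the `WEIGHTS_NAME` and `CONFIG_NAME` constants (the project imports them elsewhere); the helper name is illustrative:

```python
import os

# Assumed values; the diff itself imports WEIGHTS_NAME and CONFIG_NAME.
WEIGHTS_NAME = "adapter_model.bin"
CONFIG_NAME = "adapter_config.json"

def validate_lora_checkpoint(checkpoint_dir: str) -> None:  # hypothetical helper
    # No adapter weights at all: the path cannot be a LoRA checkpoint.
    if not os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_NAME)):
        raise ValueError("Provided path ({}) does not contain a LoRA weight.".format(checkpoint_dir))
    # Weights without an adapter config: likely a full/freeze checkpoint,
    # which must be loaded with a different --finetuning_type.
    if not os.path.exists(os.path.join(checkpoint_dir, CONFIG_NAME)):
        raise ValueError("The given checkpoint may be not a LoRA checkpoint, "
                         "please specify `--finetuning_type full/freeze` instead.")
```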
@@ -449,7 +450,7 @@ def preprocess_data(
             yield dialog

     def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `[BOS] X1 X2 X3 ...` (without [EOS])
+        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
         text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
         concatenated_ids = list(chain(*text_ids))
         total_length = len(concatenated_ids)
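
For context, the grouped-text packing this comment describes concatenates all tokenized documents and re-chunks them into fixed-size blocks, each now starting with `<bos>` and carrying no `<eos>`. A minimal sketch; the `group_texts` name, the `block_size` parameter, and the exact bos handling are assumptions, since the chunking code sits below the lines shown in this hunk:

```python
from itertools import chain

def group_texts(text_ids, bos_token_id, block_size):  # hypothetical helper
    # Concatenate every tokenized document into one long id stream.
    concatenated_ids = list(chain(*text_ids))
    # Reserve one slot per block for the leading <bos> token.
    chunk = block_size - 1
    total_length = (len(concatenated_ids) // chunk) * chunk
    # Re-chunk into `<bos> X1 X2 X3 ...` blocks without <eos>.
    return [
        [bos_token_id] + concatenated_ids[i: i + chunk]
        for i in range(0, total_length, chunk)
    ]
```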
@@ -465,7 +466,7 @@ def preprocess_data(
         }

     def preprocess_supervised_dataset(examples):
-        # build inputs with format `X [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] Y [EOS]`
+        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
         # for input with history, we build multiple input-label pairs just like:
         # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
         model_inputs = {"input_ids": [], "labels": []}
@@ -475,15 +476,26 @@ def preprocess_data(
             for i in range(len(dialog) // 2):
                 source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
                 target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
-                input_ids += source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
+
+                if len(source_ids) > data_args.max_source_length - 1: # bos token
+                    source_ids = source_ids[:data_args.max_source_length - 1]
+                if len(target_ids) > data_args.max_target_length - 1: # eos token
+                    target_ids = target_ids[:data_args.max_target_length - 1]
+
+                input_ids += [tokenizer.bos_token_id] + source_ids + target_ids + [tokenizer.eos_token_id]
                 labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]
-            model_inputs["input_ids"].append(input_ids[:data_args.max_source_length + data_args.max_target_length])
-            model_inputs["labels"].append(labels[:data_args.max_source_length + data_args.max_target_length])
+
+            if len(input_ids) > data_args.max_source_length + data_args.max_target_length:
+                input_ids = input_ids[:data_args.max_source_length + data_args.max_target_length]
+            if len(labels) > data_args.max_source_length + data_args.max_target_length:
+                labels = labels[:data_args.max_source_length + data_args.max_target_length]
+
             model_inputs["input_ids"].append(input_ids)
             model_inputs["labels"].append(labels)
         return model_inputs
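
For clarity, the per-turn logic above reduces to the following standalone sketch of the new `<bos> X Y <eos>` layout (single turn only; the helper name and flat argument list are illustrative, not part of the diff):

```python
IGNORE_INDEX = -100  # positions with this label are excluded from the loss

def build_supervised_example(source_ids, target_ids, bos_id, eos_id,
                             max_source_length, max_target_length):  # hypothetical
    # Truncate before adding specials: one slot is reserved for <bos>
    # on the source side and one for <eos> on the target side.
    source_ids = source_ids[:max_source_length - 1]
    target_ids = target_ids[:max_target_length - 1]
    # New layout `<bos> X Y <eos>` (the old one was `X <bos> Y <eos>`).
    input_ids = [bos_id] + source_ids + target_ids + [eos_id]
    # Mask <bos> and the prompt; supervise only the response and <eos>.
    labels = [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [eos_id]
    return input_ids, labels

# e.g. build_supervised_example([5, 6], [7], 1, 2, 8, 8)
# -> ([1, 5, 6, 7, 2], [-100, -100, -100, 7, 2])
```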
     def preprocess_unsupervised_dataset(examples):
-        # build inputs with format `X [BOS]` and labels with format `Y [BOS]`
+        # build inputs with format `<bos> X` and labels with format `<bos> Y`
         model_inputs = {"input_ids": [], "labels": []}
         for dialog in get_dialog(examples):
             prompt, answer = "".join(dialog[:-1]), dialog[-1]
@@ -496,15 +508,15 @@ def preprocess_data(
             if len(target_ids) > data_args.max_target_length - 1: # bos token
                 target_ids = target_ids[:data_args.max_target_length - 1]

-            input_ids = source_ids + [tokenizer.bos_token_id]
-            labels = target_ids + [tokenizer.bos_token_id]
+            input_ids = [tokenizer.bos_token_id] + source_ids
+            labels = [tokenizer.bos_token_id] + target_ids

             model_inputs["input_ids"].append(input_ids)
             model_inputs["labels"].append(labels)
         return model_inputs
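
The same move applies here: both sequences now lead with `<bos>` rather than trailing it, matching decoder-style models that expect `<bos>` at the start of a sequence. A compact sketch (helper name illustrative):

```python
def build_unsupervised_example(source_ids, target_ids, bos_id,
                               max_source_length, max_target_length):  # hypothetical
    # One slot in each budget is reserved for the leading <bos>.
    input_ids = [bos_id] + source_ids[:max_source_length - 1]
    labels = [bos_id] + target_ids[:max_target_length - 1]
    return input_ids, labels
```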
     def preprocess_pairwise_dataset(examples):
-        # build input pairs with format `X [BOS] Y1 [EOS]` and `X [BOS] Y2 [EOS]`
+        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
         model_inputs = {"accept_ids": [], "reject_ids": []}
         for dialog in get_dialog(examples):
             prompt, answer = "".join(dialog[:-1]), dialog[-1]
@@ -520,8 +532,8 @@ def preprocess_data(
             if len(reject_ids) > data_args.max_target_length - 1: # eos token
                 reject_ids = reject_ids[:data_args.max_target_length - 1]

-            accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
-            reject_ids = source_ids + [tokenizer.bos_token_id] + reject_ids + [tokenizer.eos_token_id]
+            accept_ids = [tokenizer.bos_token_id] + source_ids + accept_ids + [tokenizer.eos_token_id]
+            reject_ids = [tokenizer.bos_token_id] + source_ids + reject_ids + [tokenizer.eos_token_id]

             model_inputs["accept_ids"].append(accept_ids)
             model_inputs["reject_ids"].append(reject_ids)