support save full model, replace BOS token
Former-commit-id: 32e56c290802ba971c08f471b94a33daec85671a
This commit is contained in:
@@ -103,8 +103,9 @@ def _init_adapter(
|
||||
lastest_checkpoint = None
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)) and \
|
||||
not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
|
||||
if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)):
|
||||
raise ValueError("Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0]))
|
||||
if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
|
||||
raise ValueError("The given checkpoint may be not a LoRA checkpoint, \
|
||||
please specify `--finetuning_type full/freeze` instead.")
|
||||
|
||||
@@ -449,7 +450,7 @@ def preprocess_data(
|
||||
yield dialog
|
||||
|
||||
def preprocess_pretrain_dataset(examples):
|
||||
# build grouped texts with format `[BOS] X1 X2 X3 ...` (without [EOS])
|
||||
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
|
||||
text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
|
||||
concatenated_ids = list(chain(*text_ids))
|
||||
total_length = len(concatenated_ids)
|
||||
@@ -465,7 +466,7 @@ def preprocess_data(
|
||||
}
|
||||
|
||||
def preprocess_supervised_dataset(examples):
|
||||
# build inputs with format `X [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] Y [EOS]`
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for input with history, we build multiple input-label pairs just like:
|
||||
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
@@ -475,15 +476,26 @@ def preprocess_data(
|
||||
for i in range(len(dialog) // 2):
|
||||
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
|
||||
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
|
||||
input_ids += source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
|
||||
|
||||
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
||||
source_ids = source_ids[:data_args.max_source_length - 1]
|
||||
if len(target_ids) > data_args.max_target_length - 1: # eos token
|
||||
target_ids = target_ids[:data_args.max_target_length - 1]
|
||||
|
||||
input_ids += [tokenizer.bos_token_id] + source_ids + target_ids + [tokenizer.eos_token_id]
|
||||
labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids[:data_args.max_source_length + data_args.max_target_length])
|
||||
model_inputs["labels"].append(labels[:data_args.max_source_length + data_args.max_target_length])
|
||||
if len(input_ids) > data_args.max_source_length + data_args.max_target_length:
|
||||
input_ids = input_ids[:data_args.max_source_length + data_args.max_target_length]
|
||||
if len(labels) > data_args.max_source_length + data_args.max_target_length:
|
||||
labels = labels[:data_args.max_source_length + data_args.max_target_length]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["labels"].append(labels)
|
||||
return model_inputs
|
||||
|
||||
def preprocess_unsupervised_dataset(examples):
|
||||
# build inputs with format `X [BOS]` and labels with format `Y [BOS]`
|
||||
# build inputs with format `<bos> X` and labels with format `<bos> Y`
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
@@ -496,15 +508,15 @@ def preprocess_data(
|
||||
if len(target_ids) > data_args.max_target_length - 1: # bos token
|
||||
target_ids = target_ids[:data_args.max_target_length - 1]
|
||||
|
||||
input_ids = source_ids + [tokenizer.bos_token_id]
|
||||
labels = target_ids + [tokenizer.bos_token_id]
|
||||
input_ids = [tokenizer.bos_token_id] + source_ids
|
||||
labels = [tokenizer.bos_token_id] + target_ids
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["labels"].append(labels)
|
||||
return model_inputs
|
||||
|
||||
def preprocess_pairwise_dataset(examples):
|
||||
# build input pairs with format `X [BOS] Y1 [EOS]` and `X [BOS] Y2 [EOS]`
|
||||
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
|
||||
model_inputs = {"accept_ids": [], "reject_ids": []}
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
@@ -520,8 +532,8 @@ def preprocess_data(
|
||||
if len(reject_ids) > data_args.max_target_length - 1: # eos token
|
||||
reject_ids = reject_ids[:data_args.max_target_length - 1]
|
||||
|
||||
accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
|
||||
reject_ids = source_ids + [tokenizer.bos_token_id] + reject_ids + [tokenizer.eos_token_id]
|
||||
accept_ids = [tokenizer.bos_token_id] + source_ids + accept_ids + [tokenizer.eos_token_id]
|
||||
reject_ids = [tokenizer.bos_token_id] + source_ids + reject_ids + [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["accept_ids"].append(accept_ids)
|
||||
model_inputs["reject_ids"].append(reject_ids)
|
||||
|
||||
Reference in New Issue
Block a user