|
|
|
|
@@ -183,6 +183,7 @@ def load_pretrained(
|
|
|
|
|
load_in_8bit=True,
|
|
|
|
|
llm_int8_threshold=6.0
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
elif model_args.quantization_bit == 4:
|
|
|
|
|
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
|
|
|
|
require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
|
|
|
|
|
@@ -195,6 +196,7 @@ def load_pretrained(
|
|
|
|
|
bnb_4bit_use_double_quant=model_args.double_quantization,
|
|
|
|
|
bnb_4bit_quant_type=model_args.quantization_type
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
is_mergeable = False
|
|
|
|
|
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
|
|
|
|
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
|
|
|
|
@@ -211,10 +213,20 @@ def load_pretrained(
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
model_to_load,
|
|
|
|
|
config=config,
|
|
|
|
|
torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
|
|
|
|
|
torch_dtype=model_args.compute_dtype,
|
|
|
|
|
low_cpu_mem_usage=True,
|
|
|
|
|
**config_kwargs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Register auto class to save the custom code files.
|
|
|
|
|
if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
|
|
|
|
|
config.__class__.register_for_auto_class()
|
|
|
|
|
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
|
|
|
|
|
tokenizer.__class__.register_for_auto_class()
|
|
|
|
|
if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
|
|
|
|
|
model.__class__.register_for_auto_class()
|
|
|
|
|
|
|
|
|
|
# Initialize adapters
|
|
|
|
|
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
|
|
|
|
|
model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
|
|
|
|
|
|
|
|
|
|
@@ -487,49 +499,49 @@ def preprocess_data(
|
|
|
|
|
# for input with history, we build multiple input-label pairs just like:
|
|
|
|
|
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
|
|
|
|
|
model_inputs = {"input_ids": [], "labels": []}
|
|
|
|
|
max_length = data_args.max_source_length + data_args.max_target_length
|
|
|
|
|
|
|
|
|
|
for dialog in get_dialog(examples):
|
|
|
|
|
input_ids, labels = [], []
|
|
|
|
|
|
|
|
|
|
for i in range(len(dialog) // 2):
|
|
|
|
|
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
|
|
|
|
|
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=True)
|
|
|
|
|
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
|
|
|
|
|
|
|
|
|
|
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
|
|
|
|
source_ids = source_ids[:data_args.max_source_length - 1]
|
|
|
|
|
if len(source_ids) > data_args.max_source_length:
|
|
|
|
|
source_ids = source_ids[:data_args.max_source_length]
|
|
|
|
|
if len(target_ids) > data_args.max_target_length - 1: # eos token
|
|
|
|
|
target_ids = target_ids[:data_args.max_target_length - 1]
|
|
|
|
|
|
|
|
|
|
input_ids += [tokenizer.bos_token_id] + source_ids + target_ids + [tokenizer.eos_token_id]
|
|
|
|
|
labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]
|
|
|
|
|
if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
if len(input_ids) > data_args.max_source_length + data_args.max_target_length:
|
|
|
|
|
input_ids = input_ids[:data_args.max_source_length + data_args.max_target_length]
|
|
|
|
|
if len(labels) > data_args.max_source_length + data_args.max_target_length:
|
|
|
|
|
labels = labels[:data_args.max_source_length + data_args.max_target_length]
|
|
|
|
|
input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
|
|
|
|
|
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
|
|
|
|
|
|
|
|
|
|
model_inputs["input_ids"].append(input_ids)
|
|
|
|
|
model_inputs["labels"].append(labels)
|
|
|
|
|
|
|
|
|
|
return model_inputs
|
|
|
|
|
|
|
|
|
|
def preprocess_unsupervised_dataset(examples):
|
|
|
|
|
# build inputs with format `<bos> X` and labels with format `<bos> Y`
|
|
|
|
|
model_inputs = {"input_ids": [], "labels": []}
|
|
|
|
|
|
|
|
|
|
for dialog in get_dialog(examples):
|
|
|
|
|
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
|
|
|
|
|
|
|
|
|
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
|
|
|
|
|
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
|
|
|
|
|
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
|
|
|
|
target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
|
|
|
|
source_ids = source_ids[:data_args.max_source_length - 1]
|
|
|
|
|
if len(target_ids) > data_args.max_target_length - 1: # bos token
|
|
|
|
|
target_ids = target_ids[:data_args.max_target_length - 1]
|
|
|
|
|
if len(source_ids) > data_args.max_source_length:
|
|
|
|
|
source_ids = source_ids[:data_args.max_source_length]
|
|
|
|
|
if len(target_ids) > data_args.max_target_length:
|
|
|
|
|
target_ids = target_ids[:data_args.max_target_length]
|
|
|
|
|
|
|
|
|
|
input_ids = [tokenizer.bos_token_id] + source_ids
|
|
|
|
|
labels = [tokenizer.bos_token_id] + target_ids
|
|
|
|
|
model_inputs["input_ids"].append(source_ids)
|
|
|
|
|
model_inputs["labels"].append(target_ids)
|
|
|
|
|
|
|
|
|
|
model_inputs["input_ids"].append(input_ids)
|
|
|
|
|
model_inputs["labels"].append(labels)
|
|
|
|
|
return model_inputs
|
|
|
|
|
|
|
|
|
|
def preprocess_pairwise_dataset(examples):
|
|
|
|
|
@@ -538,19 +550,19 @@ def preprocess_data(
|
|
|
|
|
for dialog in get_dialog(examples):
|
|
|
|
|
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
|
|
|
|
|
|
|
|
|
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
|
|
|
|
|
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
|
|
|
|
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
|
|
|
|
|
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
|
|
|
|
|
|
|
|
|
|
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
|
|
|
|
source_ids = source_ids[:data_args.max_source_length - 1]
|
|
|
|
|
if len(source_ids) > data_args.max_source_length:
|
|
|
|
|
source_ids = source_ids[:data_args.max_source_length]
|
|
|
|
|
if len(accept_ids) > data_args.max_target_length - 1: # eos token
|
|
|
|
|
accept_ids = accept_ids[:data_args.max_target_length - 1]
|
|
|
|
|
if len(reject_ids) > data_args.max_target_length - 1: # eos token
|
|
|
|
|
reject_ids = reject_ids[:data_args.max_target_length - 1]
|
|
|
|
|
|
|
|
|
|
accept_ids = [tokenizer.bos_token_id] + source_ids + accept_ids + [tokenizer.eos_token_id]
|
|
|
|
|
reject_ids = [tokenizer.bos_token_id] + source_ids + reject_ids + [tokenizer.eos_token_id]
|
|
|
|
|
accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
|
|
|
|
|
reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
|
|
|
|
|
|
|
|
|
|
model_inputs["accept_ids"].append(accept_ids)
|
|
|
|
|
model_inputs["reject_ids"].append(reject_ids)
|
|
|
|
|
|