support RM metrics, add generating Args
Former-commit-id: c461c6190bc124e98dde7f3cf96a59ce40b26fb0
This commit is contained in:
@@ -36,7 +36,8 @@ from trl import AutoModelForCausalLMWithValueHead
|
||||
from .config import (
|
||||
ModelArguments,
|
||||
DataTrainingArguments,
|
||||
FinetuningArguments
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
)
|
||||
|
||||
from .template import Template
|
||||
@@ -54,7 +55,8 @@ check_min_version("4.29.1")
|
||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
|
||||
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
|
||||
require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
|
||||
require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -91,12 +93,10 @@ def _init_adapter(
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
assert is_mergeable and len(
|
||||
model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
|
||||
load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
|
||||
assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
|
||||
load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
|
||||
else:
|
||||
assert is_mergeable or len(
|
||||
model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
|
||||
assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: LoRA")
|
||||
@@ -106,8 +106,7 @@ def _init_adapter(
|
||||
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
|
||||
"The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
|
||||
|
||||
if (is_trainable and model_args.resume_lora_training) or (
|
||||
not is_mergeable): # continually train on the lora weights
|
||||
if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
|
||||
checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||
else:
|
||||
checkpoints_to_merge = model_args.checkpoint_dir
|
||||
@@ -119,10 +118,10 @@ def _init_adapter(
|
||||
if len(checkpoints_to_merge) > 0:
|
||||
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
|
||||
|
||||
if lastest_checkpoint is not None: # resume lora training or quantized inference
|
||||
if lastest_checkpoint is not None: # resume lora training or quantized inference
|
||||
model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable)
|
||||
|
||||
if is_trainable and lastest_checkpoint is None: # create new lora weights while training
|
||||
if is_trainable and lastest_checkpoint is None: # create new lora weights while training
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
@@ -170,7 +169,7 @@ def load_pretrained(
|
||||
padding_side="left",
|
||||
**config_kwargs
|
||||
)
|
||||
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
|
||||
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
|
||||
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
is_mergeable = True
|
||||
@@ -186,11 +185,9 @@ def load_pretrained(
|
||||
)
|
||||
elif model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
require_version("transformers>=4.30.0.dev0",
|
||||
"To fix: pip install git+https://github.com/huggingface/transformers.git")
|
||||
require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
|
||||
require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
|
||||
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
|
||||
require_version("accelerate>=0.20.0.dev0",
|
||||
"To fix: pip install git+https://github.com/huggingface/accelerate.git")
|
||||
config_kwargs["load_in_4bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
@@ -201,10 +198,10 @@ def load_pretrained(
|
||||
else:
|
||||
raise NotImplementedError
|
||||
is_mergeable = False
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)}
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
if not is_trainable:
|
||||
if not is_trainable: # `device_map=auto` should be used for inference only
|
||||
config_kwargs["device_map"] = "auto"
|
||||
|
||||
# Load and prepare pretrained models (without valuehead).
|
||||
@@ -218,24 +215,26 @@ def load_pretrained(
|
||||
model = prepare_model_for_training(model) if is_trainable else model
|
||||
model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
|
||||
|
||||
if stage == "rm" or stage == "ppo": # add value head
|
||||
if stage == "rm" or stage == "ppo": # add value head
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
|
||||
if stage == "ppo": # load reward model
|
||||
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
|
||||
load_valuehead_params(model, model_args.checkpoint_dir[0])
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "reward_head_weight"),
|
||||
"summary.bias": getattr(model, "reward_head_bias")
|
||||
})
|
||||
|
||||
if stage == "ppo": # load reward model
|
||||
assert is_trainable, "PPO stage cannot be performed at evaluation."
|
||||
assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
|
||||
logger.info("Load reward model from {}".format(model_args.reward_model))
|
||||
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
|
||||
load_valuehead_params(model, model_args.reward_model)
|
||||
|
||||
# Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model
|
||||
# To meet the compliance requirements of the transformers library
|
||||
if model_args.quantization_bit is not None:
|
||||
model._is_int8_training_enabled = True
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
|
||||
|
||||
print_trainable_params(model)
|
||||
|
||||
@@ -245,11 +244,11 @@ def load_pretrained(
|
||||
def prepare_args(
|
||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
|
||||
model_args, data_args, training_args, finetuning_args = parser.parse_json_file(
|
||||
json_file=os.path.abspath(sys.argv[1]))
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
|
||||
model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
@@ -290,7 +289,7 @@ def prepare_args(
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
|
||||
training_args.ddp_find_unused_parameters = False
|
||||
|
||||
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
|
||||
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
if training_args.fp16:
|
||||
@@ -313,13 +312,14 @@ def prepare_args(
|
||||
return model_args, data_args, training_args, finetuning_args
|
||||
|
||||
|
||||
def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]:
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments))
|
||||
def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments]:
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
|
||||
model_args, data_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments))
|
||||
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses()
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
@@ -327,13 +327,14 @@ def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, Finetun
|
||||
if data_args.prompt_template == "alpaca":
|
||||
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
|
||||
|
||||
return model_args, data_args, finetuning_args
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def prepare_data(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataTrainingArguments
|
||||
) -> Dataset:
|
||||
|
||||
def checksum(file_path, hash):
|
||||
with open(file_path, "rb") as datafile:
|
||||
binary_data = datafile.read()
|
||||
@@ -342,7 +343,7 @@ def prepare_data(
|
||||
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
|
||||
|
||||
max_samples = data_args.max_samples
|
||||
all_datasets: List[Dataset] = [] # support multiple datasets
|
||||
all_datasets: List[Dataset] = [] # support multiple datasets
|
||||
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
|
||||
@@ -358,10 +359,12 @@ def prepare_data(
|
||||
elif dataset_attr.load_from == "file":
|
||||
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
|
||||
extension = dataset_attr.file_name.split(".")[-1]
|
||||
|
||||
if dataset_attr.file_sha1 is not None:
|
||||
checksum(data_file, dataset_attr.file_sha1)
|
||||
else:
|
||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
||||
|
||||
raw_datasets = load_dataset(
|
||||
extension if extension in ["csv", "json"] else "text",
|
||||
data_files=data_file,
|
||||
@@ -383,11 +386,11 @@ def prepare_data(
|
||||
("query_column", "query"),
|
||||
("response_column", "response"),
|
||||
("history_column", "history")
|
||||
]: # every dataset will have 4 columns same as each other
|
||||
]: # every dataset will have 4 columns same as each other
|
||||
if getattr(dataset_attr, column_name) != target_name:
|
||||
if getattr(dataset_attr, column_name):
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
|
||||
else: # None or empty string
|
||||
else: # None or empty string
|
||||
dataset = dataset.add_column(target_name, dummy_data)
|
||||
all_datasets.append(dataset)
|
||||
|
||||
@@ -406,6 +409,7 @@ def preprocess_data(
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||
) -> Dataset:
|
||||
|
||||
column_names = list(dataset.column_names)
|
||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||
prompt_template = Template(data_args.prompt_template)
|
||||
@@ -442,9 +446,9 @@ def preprocess_data(
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
|
||||
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
||||
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
||||
source_ids = source_ids[:data_args.max_source_length - 1]
|
||||
if len(target_ids) > data_args.max_target_length - 1: # eos token
|
||||
if len(target_ids) > data_args.max_target_length - 1: # eos token
|
||||
target_ids = target_ids[:data_args.max_target_length - 1]
|
||||
|
||||
input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
|
||||
@@ -461,9 +465,9 @@ def preprocess_data(
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
|
||||
target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
||||
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
||||
source_ids = source_ids[:data_args.max_source_length - 1]
|
||||
if len(target_ids) > data_args.max_target_length - 1: # bos token
|
||||
if len(target_ids) > data_args.max_target_length - 1: # bos token
|
||||
target_ids = target_ids[:data_args.max_target_length - 1]
|
||||
|
||||
input_ids = source_ids + [tokenizer.bos_token_id]
|
||||
@@ -481,11 +485,11 @@ def preprocess_data(
|
||||
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
|
||||
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
||||
if len(source_ids) > data_args.max_source_length - 1: # bos token
|
||||
source_ids = source_ids[:data_args.max_source_length - 1]
|
||||
if len(accept_ids) > data_args.max_target_length - 1: # eos token
|
||||
if len(accept_ids) > data_args.max_target_length - 1: # eos token
|
||||
accept_ids = accept_ids[:data_args.max_target_length - 1]
|
||||
if len(reject_ids) > data_args.max_target_length - 1: # eos token
|
||||
if len(reject_ids) > data_args.max_target_length - 1: # eos token
|
||||
reject_ids = reject_ids[:data_args.max_target_length - 1]
|
||||
|
||||
accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
|
||||
|
||||
Reference in New Issue
Block a user