support RM metrics, add generating Args

Former-commit-id: c461c6190bc124e98dde7f3cf96a59ce40b26fb0
This commit is contained in:
hiyouga
2023-06-12 15:48:48 +08:00
parent 4c5cad9722
commit 4724ae3492
16 changed files with 177 additions and 163 deletions

View File

@@ -1,12 +1,13 @@
import os
import json
import torch
from typing import List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional
from dataclasses import asdict, dataclass, field
@dataclass
class DatasetAttr:
load_from: str
dataset_name: Optional[str] = None
file_name: Optional[str] = None
@@ -55,11 +56,11 @@ class ModelArguments:
)
quantization_type: Optional[Literal["fp4", "nf4"]] = field(
default="nf4",
metadata={"help": "Quantization data type to use."}
metadata={"help": "Quantization data type to use in int4 training."}
)
double_quantization: Optional[bool] = field(
default=True,
metadata={"help": "Compress the quantization statistics through double quantization."}
metadata={"help": "Whether to use double quantization in int4 training or not."}
)
compute_dtype: Optional[torch.dtype] = field(
default=None,
@@ -67,8 +68,7 @@ class ModelArguments:
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={
"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
)
reward_model: Optional[str] = field(
default=None,
@@ -76,8 +76,7 @@ class ModelArguments:
)
resume_lora_training: Optional[bool] = field(
default=True,
metadata={
"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
plot_loss: Optional[bool] = field(
default=False,
@@ -85,7 +84,7 @@ class ModelArguments:
)
def __post_init__(self):
if self.checkpoint_dir is not None: # support merging multiple lora weights
if self.checkpoint_dir is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
@@ -147,7 +146,7 @@ class DataTrainingArguments:
metadata={"help": "Which template to use for constructing prompts in training and inference."}
)
def __post_init__(self): # support mixing multiple datasets
def __post_init__(self): # support mixing multiple datasets
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f)
@@ -156,42 +155,25 @@ class DataTrainingArguments:
for name in dataset_names:
if name not in dataset_info:
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
dataset_attrs = []
dataset_attr = None
if "hf_hub_url" in dataset_info[name]:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
elif os.path.isfile(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
else:
dataset_attr = DatasetAttr(
"file",
file_name=dataset_info[name]["file_name"],
file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
)
else:
# Support Directory
for file_name in os.listdir(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
path = os.path.join(dataset_info[name]["file_name"], file_name)
dataset_attrs.append(DatasetAttr(
"file",
file_name=path,
file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
))
if dataset_attr is not None:
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)
else:
for i, dataset_attr in enumerate(dataset_attrs):
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)
@dataclass
@@ -228,22 +210,20 @@ class FinetuningArguments:
lora_target: Optional[str] = field(
default="q_proj,v_proj",
metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \
BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"}
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"gate_proj\", \"down_proj\"], \
BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"]"}
)
def __post_init__(self):
if isinstance(self.lora_target, str):
self.lora_target = [target.strip() for target in
self.lora_target.split(",")] # support custom target modules of LoRA
if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
self.lora_target = [target.strip() for target in self.lora_target.split(",")]
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in
trainable_layer_ids]
self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
@@ -259,3 +239,44 @@ class FinetuningArguments:
with open(json_path, "r", encoding="utf-8") as f:
text = f.read()
return cls(**json.loads(text))
@dataclass
class GeneratingArguments:
"""
Arguments pertaining to specify the decoding parameters.
"""
do_sample: Optional[bool] = field(
default=True,
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
)
temperature: Optional[float] = field(
default=0.95,
metadata={"help": "The value used to modulate the next token probabilities."}
)
top_p: Optional[float] = field(
default=0.7,
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
)
top_k: Optional[int] = field(
default=50,
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
)
infer_num_beams: Optional[int] = field(
default=1,
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
)
max_new_tokens: Optional[int] = field(
default=512,
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
)
repetition_penalty: Optional[float] = field(
default=1.0,
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
)
def to_dict(self) -> Dict[str, Any]:
data_dict = asdict(self)
num_beams = data_dict.pop("infer_num_beams")
data_dict["num_beams"] = num_beams
return data_dict