add code for reading from multi files in one directory

Former-commit-id: b7ebb83a96619e5111b0faa9da9d0feb8d9cdff0
This commit is contained in:
BUAADreamer
2023-06-10 15:53:47 +08:00
parent 03c92c79ff
commit ef6c5ae18a
2 changed files with 76 additions and 54 deletions

View File

@@ -7,7 +7,6 @@ from dataclasses import asdict, dataclass, field
@dataclass
class DatasetAttr:
load_from: str
dataset_name: Optional[str] = None
file_name: Optional[str] = None
@@ -68,7 +67,8 @@ class ModelArguments:
)
checkpoint_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
metadata={
"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
)
reward_model: Optional[str] = field(
default=None,
@@ -76,7 +76,8 @@ class ModelArguments:
)
resume_lora_training: Optional[bool] = field(
default=True,
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
metadata={
"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
)
plot_loss: Optional[bool] = field(
default=False,
@@ -84,7 +85,7 @@ class ModelArguments:
)
def __post_init__(self):
if self.checkpoint_dir is not None: # support merging multiple lora weights
if self.checkpoint_dir is not None: # support merging multiple lora weights
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
@@ -146,7 +147,7 @@ class DataTrainingArguments:
metadata={"help": "Which template to use for constructing prompts in training and inference."}
)
def __post_init__(self): # support mixing multiple datasets
def __post_init__(self): # support mixing multiple datasets
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f)
@@ -155,25 +156,42 @@ class DataTrainingArguments:
for name in dataset_names:
if name not in dataset_info:
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
dataset_attrs = []
dataset_attr = None
if "hf_hub_url" in dataset_info[name]:
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
elif os.path.isfile(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
dataset_attr = DatasetAttr(
"file",
file_name=dataset_info[name]["file_name"],
file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
)
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)
else:
# Support Directory
for file_name in os.listdir(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
path = os.path.join(dataset_info[name]["file_name"], file_name)
dataset_attrs.append(DatasetAttr(
"file",
file_name=path,
file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
))
if dataset_attr is not None:
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)
else:
for i, dataset_attr in enumerate(dataset_attrs):
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)
@dataclass
@@ -216,14 +234,16 @@ class FinetuningArguments:
def __post_init__(self):
if isinstance(self.lora_target, str):
self.lora_target = [target.strip() for target in self.lora_target.split(",")] # support custom target modules of LoRA
self.lora_target = [target.strip() for target in
self.lora_target.split(",")] # support custom target modules of LoRA
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [27-k for k in range(self.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)]
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in
trainable_layer_ids]
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."