add code for reading from multiple files in one directory
Former-commit-id: b7ebb83a96619e5111b0faa9da9d0feb8d9cdff0
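Note on the change: a dataset entry in dataset_info.json can now point its "file_name" at a directory instead of a single file. When that path is not a regular file, every file inside the directory is registered as its own "file" DatasetAttr and appended to dataset_list (an optional file_sha1 is applied to each file). A rough sketch of such an entry, with the dataset name, directory name and column values invented for illustration:

    "my_sharded_corpus": {
        "file_name": "my_corpus_dir",
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "response": "output"
        }
    }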
@@ -7,7 +7,6 @@ from dataclasses import asdict, dataclass, field
 
 @dataclass
 class DatasetAttr:
-
     load_from: str
     dataset_name: Optional[str] = None
     file_name: Optional[str] = None
@@ -68,7 +67,8 @@ class ModelArguments:
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
+        metadata={
+            "help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
     )
     reward_model: Optional[str] = field(
         default=None,
@@ -76,7 +76,8 @@ class ModelArguments:
     )
     resume_lora_training: Optional[bool] = field(
         default=True,
-        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+        metadata={
+            "help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
     )
     plot_loss: Optional[bool] = field(
         default=False,
@@ -84,7 +85,7 @@ class ModelArguments:
     )
 
     def __post_init__(self):
-        if self.checkpoint_dir is not None: # support merging multiple lora weights
+        if self.checkpoint_dir is not None:  # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
 
 
@@ -146,7 +147,7 @@ class DataTrainingArguments:
         metadata={"help": "Which template to use for constructing prompts in training and inference."}
     )
 
-    def __post_init__(self): # support mixing multiple datasets
+    def __post_init__(self):  # support mixing multiple datasets
         dataset_names = [ds.strip() for ds in self.dataset.split(",")]
         with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
             dataset_info = json.load(f)
@@ -155,25 +156,42 @@ class DataTrainingArguments:
         for name in dataset_names:
             if name not in dataset_info:
                 raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
-
+            dataset_attrs = []
+            dataset_attr = None
             if "hf_hub_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
             elif "script_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
-            else:
+            elif os.path.isfile(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
                 dataset_attr = DatasetAttr(
                     "file",
                     file_name=dataset_info[name]["file_name"],
                     file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
                 )
-
-            if "columns" in dataset_info[name]:
-                dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
-                dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
-                dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
-                dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
-
-            self.dataset_list.append(dataset_attr)
+            else:
+                # Support Directory
+                for file_name in os.listdir(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
+                    path = os.path.join(dataset_info[name]["file_name"], file_name)
+                    dataset_attrs.append(DatasetAttr(
+                        "file",
+                        file_name=path,
+                        file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
+                    ))
+            if dataset_attr is not None:
+                if "columns" in dataset_info[name]:
+                    dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
+                    dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
+                    dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
+                    dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
+                self.dataset_list.append(dataset_attr)
+            else:
+                for i, dataset_attr in enumerate(dataset_attrs):
+                    if "columns" in dataset_info[name]:
+                        dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
+                        dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
+                        dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
+                        dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
+                    self.dataset_list.append(dataset_attr)
 
 
 @dataclass
@@ -216,14 +234,16 @@ class FinetuningArguments:
 
     def __post_init__(self):
         if isinstance(self.lora_target, str):
-            self.lora_target = [target.strip() for target in self.lora_target.split(",")] # support custom target modules of LoRA
+            self.lora_target = [target.strip() for target in
+                                self.lora_target.split(",")]  # support custom target modules of LoRA
 
-        if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
-            trainable_layer_ids = [27-k for k in range(self.num_layer_trainable)]
-        else: # fine-tuning the first n layers if num_layer_trainable < 0
+        if self.num_layer_trainable > 0:  # fine-tuning the last n layers if num_layer_trainable > 0
+            trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)]
+        else:  # fine-tuning the first n layers if num_layer_trainable < 0
             trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
 
-        self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
+        self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in
+                                 trainable_layer_ids]
 
         assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
 
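For reference, a minimal standalone sketch of the new directory handling (dataset_dir, the entry and all file names below are hypothetical; the real logic lives in DataTrainingArguments.__post_init__):

    import os

    dataset_dir = "data"                        # hypothetical dataset_dir
    entry = {"file_name": "my_corpus_dir"}      # hypothetical dataset_info.json entry
    target = os.path.join(dataset_dir, entry["file_name"])

    if os.path.isfile(target):
        paths = [entry["file_name"]]            # single-file case, behaviour unchanged
    else:
        # directory case added by this commit: one relative path per contained file
        paths = [os.path.join(entry["file_name"], name) for name in os.listdir(target)]

    print(paths)  # each path would become its own "file" DatasetAttr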