support BLOOM models

Former-commit-id: 1314b6ea39a01aa8ac325e1d875ac013d43aec45
Author: hiyouga
Date:   2023-05-31 16:54:06 +08:00
Parent: 181c776b58
Commit: 693c049eac

16 changed files with 134 additions and 90 deletions


@@ -12,6 +12,12 @@ class DatasetAttr:
     file_name: Optional[str] = None
     file_sha1: Optional[str] = None
 
+    def __repr__(self) -> str:
+        if self.dataset_name is not None:
+            return self.dataset_name
+        else:
+            return self.file_name
+
     def __post_init__(self):
         self.prompt_column = "instruction"
         self.query_column = "input"
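
The new `__repr__` makes a `DatasetAttr` instance print as its dataset name, falling back to the file name for local datasets. A minimal stand-alone sketch (the fields are trimmed to the two that matter here and the example values are illustrative, not the project's full class):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class DatasetAttr:
    dataset_name: Optional[str] = None
    file_name: Optional[str] = None

    # A __repr__ defined in the class body takes precedence over the
    # dataclass-generated one.
    def __repr__(self) -> str:
        if self.dataset_name is not None:
            return self.dataset_name
        else:
            return self.file_name

print(DatasetAttr(dataset_name="alpaca_gpt4_en"))  # -> alpaca_gpt4_en
print(DatasetAttr(file_name="my_data.json"))       # -> my_data.json
```
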
@@ -161,9 +167,11 @@ class FinetuningArguments:
         default=3,
         metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
     )
-    name_module_trainable: Optional[Literal["mlp", "qkv"]] = field(
+    name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
         default="mlp",
-        metadata={"help": "Name of trainable modules for Freeze fine-tuning."}
+        metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
+                  LLaMA choices: [\"mlp\", \"self_attn\"], \
+                  BLOOM choices: [\"mlp\", \"self_attention\"]"}
     )
     lora_rank: Optional[int] = field(
         default=8,
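
The generic `qkv` option is replaced by the architectures' own submodule names, since LLaMA decoder layers expose `self_attn` while BLOOM blocks expose `self_attention`. A hedged sketch of that mapping; `FREEZE_MODULE_CHOICES` and `validate_freeze_module` are illustrative helpers, not part of this commit:

```python
# Submodule names as exposed by the Hugging Face implementations.
FREEZE_MODULE_CHOICES = {
    "llama": ["mlp", "self_attn"],       # LlamaDecoderLayer submodules
    "bloom": ["mlp", "self_attention"],  # BloomBlock submodules
}

def validate_freeze_module(model_type: str, name_module_trainable: str) -> None:
    """Illustrative check that the requested module exists for the architecture."""
    choices = FREEZE_MODULE_CHOICES.get(model_type, [])
    if name_module_trainable not in choices:
        raise ValueError(
            f"{name_module_trainable!r} is not valid for {model_type}; choose from {choices}"
        )

validate_freeze_module("bloom", "self_attention")  # passes
```
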
@@ -171,7 +179,7 @@ class FinetuningArguments:
     )
     lora_alpha: Optional[float] = field(
         default=32.0,
-        metadata={"help": "The scale factor for LoRA fine-tuning. (similar with the learning rate)"}
+        metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
     )
     lora_dropout: Optional[float] = field(
         default=0.1,
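
For context on why the help text compares `lora_alpha` to a learning rate: in the standard LoRA formulation the low-rank update is scaled by `lora_alpha / lora_rank`. A short worked example with the defaults above (not code from this commit):

```python
lora_rank, lora_alpha = 8, 32.0
scaling = lora_alpha / lora_rank   # 4.0
# Effective weight during fine-tuning: W_eff = W + scaling * (B @ A),
# so raising lora_alpha amplifies the adapter's contribution much like a
# larger learning rate would.
```
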
@@ -179,7 +187,9 @@ class FinetuningArguments:
     )
     lora_target: Optional[str] = field(
         default="q_proj,v_proj",
-        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."}
+        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
+                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"mlp\"], \
+                  BLOOM choices: [\"query_key_value\", \"dense\", \"mlp\"]"}
     )
 
     def __post_init__(self):
@@ -191,11 +201,7 @@ class FinetuningArguments:
         else: # fine-tuning the first n layers if num_layer_trainable < 0
             trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
 
-        if self.name_module_trainable == "mlp":
-            self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids]
-        elif self.name_module_trainable == "qkv":
-            self.trainable_layers = ["layers.{:d}.self_attn.{}".format(idx, proj) \
-                                     for proj in ["k_proj", "q_proj", "v_proj", "o_proj"] for idx in trainable_layer_ids]
+        self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
 
         assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
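
Because `name_module_trainable` now holds the actual submodule name, the two branches collapse into a single format string. A hedged sketch of how such a `trainable_layers` list is typically consumed to freeze everything else; the helper below is illustrative, not copied from the repository, and the name matching assumes LLaMA-style parameter names such as `model.layers.0.mlp.gate_proj.weight`:

```python
import torch

def apply_freeze_tuning(model: torch.nn.Module,
                        num_layer_trainable: int = -3,
                        name_module_trainable: str = "mlp") -> None:
    # Assumes num_layer_trainable < 0, i.e. the "first n layers" branch shown above.
    trainable_layer_ids = [k for k in range(-num_layer_trainable)]
    trainable_layers = ["layers.{:d}.{}".format(idx, name_module_trainable)
                        for idx in trainable_layer_ids]
    # Keep gradients only for parameters whose name contains one of the patterns.
    for name, param in model.named_parameters():
        param.requires_grad_(any(pattern in name for pattern in trainable_layers))
```
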