support BLOOM models
Former-commit-id: 1314b6ea39a01aa8ac325e1d875ac013d43aec45
@@ -12,6 +12,12 @@ class DatasetAttr:
     file_name: Optional[str] = None
     file_sha1: Optional[str] = None
 
+    def __repr__(self) -> str:
+        if self.dataset_name is not None:
+            return self.dataset_name
+        else:
+            return self.file_name
+
     def __post_init__(self):
         self.prompt_column = "instruction"
         self.query_column = "input"
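For context, a minimal self-contained sketch of the behavior the new __repr__ adds: it prefers the dataset name and falls back to the local file name. The reduced class below only mirrors the fields visible in this hunk (the real DatasetAttr carries more fields), and the example values are illustrative:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class DatasetAttr:  # reduced stand-in; other fields of the real class omitted
        dataset_name: Optional[str] = None
        file_name: Optional[str] = None
        file_sha1: Optional[str] = None

        def __repr__(self) -> str:
            # prefer the dataset name, fall back to the local file name
            return self.dataset_name if self.dataset_name is not None else self.file_name

    print(repr(DatasetAttr(dataset_name="alpaca_gpt4_en")))  # -> alpaca_gpt4_en
    print(repr(DatasetAttr(file_name="my_data.json")))       # -> my_data.json

Note that @dataclass leaves a user-defined __repr__ untouched, so the auto-generated one is not substituted here.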
@@ -161,9 +167,11 @@ class FinetuningArguments:
         default=3,
         metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
     )
-    name_module_trainable: Optional[Literal["mlp", "qkv"]] = field(
+    name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
         default="mlp",
-        metadata={"help": "Name of trainable modules for Freeze fine-tuning."}
+        metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
+                  LLaMA choices: [\"mlp\", \"self_attn\"], \
+                  BLOOM choices: [\"mlp\", \"self_attention\"]"}
     )
     lora_rank: Optional[int] = field(
         default=8,
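These fields are typically consumed through HfArgumentParser, so the widened Literal surfaces directly as command-line choices. A hedged sketch, assuming the usual parser wiring and a transformers version that maps Literal to argparse choices (FreezeArgs is a reduced stand-in, not the real FinetuningArguments class):

    from dataclasses import dataclass, field
    from typing import Literal, Optional
    from transformers import HfArgumentParser

    @dataclass
    class FreezeArgs:  # reduced stand-in for FinetuningArguments
        num_layer_trainable: Optional[int] = field(default=3)
        name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(default="mlp")

    parser = HfArgumentParser(FreezeArgs)
    (args,) = parser.parse_args_into_dataclasses(["--name_module_trainable", "self_attention"])
    print(args.name_module_trainable)  # self_attention (the BLOOM-style module name)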
@@ -171,7 +179,7 @@ class FinetuningArguments:
     )
     lora_alpha: Optional[float] = field(
         default=32.0,
-        metadata={"help": "The scale factor for LoRA fine-tuning. (similar with the learning rate)"}
+        metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
     )
     lora_dropout: Optional[float] = field(
         default=0.1,
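The reworded help string reflects how alpha behaves in common LoRA implementations such as PEFT, where the low-rank update is multiplied by lora_alpha / lora_rank; raising alpha therefore scales the learned update much like raising a learning rate would. A small illustration using the defaults above (a general LoRA fact, not code from this commit):

    lora_rank, lora_alpha = 8, 32.0
    scaling = lora_alpha / lora_rank  # 4.0: factor applied to the low-rank update B @ A
    print(scaling)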
@@ -179,7 +187,9 @@ class FinetuningArguments:
     )
     lora_target: Optional[str] = field(
         default="q_proj,v_proj",
-        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."}
+        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
+                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"mlp\"], \
+                  BLOOM choices: [\"query_key_value\", \"dense\", \"mlp\"]"}
     )
 
     def __post_init__(self):
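The comma-separated lora_target string is intended to be split into a list of module names for PEFT. A hedged sketch of that translation using the values documented above (how this repo's trainer actually wires it up is assumed, not shown in this hunk):

    from peft import LoraConfig

    lora_target = "query_key_value,dense"  # BLOOM-style targets; a LLaMA run might pass "q_proj,v_proj"
    lora_config = LoraConfig(
        task_type="CAUSAL_LM",
        r=8,               # lora_rank
        lora_alpha=32.0,   # lora_alpha
        lora_dropout=0.1,  # lora_dropout
        target_modules=[name.strip() for name in lora_target.split(",")]
    )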
@@ -191,11 +201,7 @@ class FinetuningArguments:
         else: # fine-tuning the first n layers if num_layer_trainable < 0
             trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
 
-        if self.name_module_trainable == "mlp":
-            self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids]
-        elif self.name_module_trainable == "qkv":
-            self.trainable_layers = ["layers.{:d}.self_attn.{}".format(idx, proj) \
-                for proj in ["k_proj", "q_proj", "v_proj", "o_proj"] for idx in trainable_layer_ids]
+        self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
 
         assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
 
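The unified literal folds the architecture difference into the module name itself: LLaMA decoder blocks expose self_attn and mlp submodules, while BLOOM blocks expose self_attention and mlp. A standalone sketch of what the simplified line produces, followed by an illustrative (assumed, not from this commit) way such names are matched when freezing parameters:

    name_module_trainable = "self_attention"  # BLOOM; a LLaMA run would use "self_attn"
    num_layer_trainable = -3                  # negative -> fine-tune the first 3 layers, as in the else branch
    trainable_layer_ids = [k for k in range(-num_layer_trainable)]  # [0, 1, 2]
    trainable_layers = ["layers.{:d}.{}".format(idx, name_module_trainable)
                        for idx in trainable_layer_ids]
    print(trainable_layers)
    # -> ['layers.0.self_attention', 'layers.1.self_attention', 'layers.2.self_attention']

    # Illustrative use: unfreeze only the parameters whose names contain one of these paths.
    # for name, param in model.named_parameters():
    #     param.requires_grad_(any(layer in name for layer in trainable_layers))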