update scripts

Former-commit-id: dabf5a1dc661a6581474c6a5ec115322d168ed5f
Author: hiyouga
Date: 2024-08-09 19:16:23 +08:00
Parent: 5af32ce705
Commit: 9d1e2c3c1f
8 changed files with 29 additions and 16 deletions

View File

@@ -43,7 +43,7 @@ def calculate_lr(
     dataset_dir: str = "data",
     template: str = "default",
     cutoff_len: int = 1024,  # i.e. maximum input length during training
-    is_mistral: bool = False,  # mistral model uses a smaller learning rate
+    is_mistral_or_gemma: bool = False,  # mistral and gemma models opt for a smaller learning rate
     packing: bool = False,
 ):
     r"""
@@ -84,7 +84,7 @@ def calculate_lr(
     valid_ratio = valid_tokens / total_tokens
     batch_valid_len = batch_max_len * valid_ratio
     lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # lr ~ sqrt(batch_size)
-    lr = lr / 6.0 if is_mistral else lr
+    lr = lr / 6.0 if is_mistral_or_gemma else lr
     print(
         "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
             lr, valid_ratio * 100, batch_valid_len
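
For intuition, the sqrt scaling rule above is easy to exercise standalone. A minimal sketch, assuming BASE_LR = 3e-4 and BASE_BS = 4,000,000 tokens (the actual constants live at the top of the script and may differ):

import math

BASE_LR = 3e-4       # assumed learning rate at the reference batch size
BASE_BS = 4_000_000  # assumed reference batch size, in valid tokens

def scaled_lr(batch_valid_len: float, is_mistral_or_gemma: bool = False) -> float:
    # lr ~ sqrt(batch_size): quartering the valid-token batch halves the lr
    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)
    return lr / 6.0 if is_mistral_or_gemma else lr

print(f"{scaled_lr(1_000_000):.2e}")        # 1.50e-04, since sqrt(1/4) = 0.5
print(f"{scaled_lr(1_000_000, True):.2e}")  # 2.50e-05, further divided by 6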

View File

@@ -19,7 +19,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 import fire
 import torch
@@ -47,8 +47,8 @@ def block_expansion(
     model_name_or_path: str,
     output_dir: str,
     num_expand: int,
-    shard_size: Optional[str] = "2GB",
-    save_safetensors: Optional[bool] = False,
+    shard_size: str = "2GB",
+    save_safetensors: bool = True,
 ):
     r"""
     Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
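
The technique the script implements (LLaMA Pro-style block expansion) interleaves num_expand copied decoder blocks among the originals and zeroes their output projections, so every new block starts as an identity map through the residual connection. A minimal sketch of the idea, assuming LLaMA-style module names (self_attn.o_proj, mlp.down_proj); the real script edits the checkpoint's state dict rather than live modules:

import copy

def expand_blocks(layers: list, num_expand: int) -> list:
    """Insert num_expand identity-initialized copies evenly among the blocks."""
    split = len(layers) // num_expand  # original blocks per inserted copy
    expanded = []
    for i, layer in enumerate(layers):
        expanded.append(layer)
        if (i + 1) % split == 0:
            new_layer = copy.deepcopy(layer)
            # zeroed output projections make the block a no-op at init:
            # only the residual path carries signal through it
            new_layer.self_attn.o_proj.weight.data.zero_()
            new_layer.mlp.down_proj.weight.data.zero_()
            expanded.append(new_layer)
    return expanded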

View File

@@ -16,7 +16,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import Any, Dict, Optional
+from typing import Any, Dict
 
 import fire
 import torch
@@ -86,7 +86,10 @@ def save_config(input_dir: str, output_dir: str):
 def llamafy_baichuan2(
-    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
+    input_dir: str,
+    output_dir: str,
+    shard_size: str = "2GB",
+    save_safetensors: bool = True,
 ):
     r"""
     Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
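
Most of the conversion is mechanical key renaming; the one structural change is splitting Baichuan2's fused QKV projection (W_pack) into LLaMA's separate q_proj/k_proj/v_proj. A rough sketch of that step over a raw state dict (the real script also post-processes the lm_head weight, which Baichuan2 normalizes at inference time):

import torch

def split_w_pack(state_dict: dict) -> dict:
    """Split the fused W_pack weight into LLaMA-style q/k/v projections."""
    llama_dict = {}
    for key, value in state_dict.items():
        if "W_pack" in key:
            q, k, v = torch.chunk(value, 3, dim=0)  # fused along the output dim
            llama_dict[key.replace("W_pack", "q_proj")] = q
            llama_dict[key.replace("W_pack", "k_proj")] = k
            llama_dict[key.replace("W_pack", "v_proj")] = v
        else:
            llama_dict[key] = value
    return llama_dict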

View File

@@ -16,7 +16,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import Any, Dict, Optional
+from typing import Any, Dict
 
 import fire
 import torch
@@ -139,7 +139,10 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
 def llamafy_qwen(
-    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
+    input_dir: str,
+    output_dir: str,
+    shard_size: str = "2GB",
+    save_safetensors: bool = False,
 ):
     r"""
     Converts the Qwen models in the same format as LLaMA2.
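
The save_safetensors flag this diff retypes only toggles the serialization backend. A minimal sketch of that step, assuming an unsharded state dict and hypothetical file names; the real scripts additionally split the output into shards no larger than shard_size:

import torch
from safetensors.torch import save_file

def save_weights(state_dict: dict, path_prefix: str, save_safetensors: bool) -> None:
    if save_safetensors:
        # safetensors stores raw tensors only, so loading runs no pickled code
        tensors = {k: v.contiguous() for k, v in state_dict.items()}
        save_file(tensors, path_prefix + ".safetensors", metadata={"format": "pt"})
    else:
        torch.save(state_dict, path_prefix + ".bin")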

View File

@@ -31,7 +31,7 @@ if TYPE_CHECKING:
 def quantize_pissa(
     model_name_or_path: str,
     output_dir: str,
-    pissa_iter: int = 4,
+    pissa_iter: int = 16,
     lora_alpha: int = None,
     lora_rank: int = 16,
     lora_dropout: float = 0,
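
For context on the retuned default: in peft, an iteration count like this maps onto the init_lora_weights string, where "pissa_niter_{n}" selects a randomized SVD with n subspace iterations, a faster approximation of the exact "pissa" init that improves as n grows. A hedged sketch of the corresponding adapter config; the target modules and task type here are assumptions:

from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    target_modules="all-linear",
    # "pissa" would run an exact SVD; "pissa_niter_16" approximates it
    # with 16 iterations of randomized SVD, matching the new default above
    init_lora_weights="pissa_niter_16",
    task_type="CAUSAL_LM",
)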