format style

Former-commit-id: 53b683531b83cd1d19de97c6565f16c1eca6f5e1
Author: hiyouga
Date: 2024-01-20 20:15:56 +08:00
Commit: 66e0e651b9 (parent 1750218057)
73 changed files with 1492 additions and 2325 deletions


@@ -3,32 +3,28 @@
 # Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB
 # Warning: We have found that the converted model cannot infer correctly. It will be fixed later.
 
-import os
-import fire
 import json
-import torch
-from tqdm import tqdm
+import os
 from collections import OrderedDict
-from safetensors.torch import save_file
-from transformers.modeling_utils import (
-    shard_checkpoint,
-    SAFE_WEIGHTS_NAME,
-    SAFE_WEIGHTS_INDEX_NAME,
-    WEIGHTS_NAME,
-    WEIGHTS_INDEX_NAME
-)
 from typing import Any, Dict, Optional
+
+import fire
+import torch
+from safetensors.torch import save_file
+from tqdm import tqdm
+from transformers.modeling_utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    WEIGHTS_NAME,
+    shard_checkpoint,
+)
 
 
 CONFIG_NAME = "config.json"
 
 
-def save_weight(
-    input_dir: str,
-    output_dir: str,
-    shard_size: str,
-    save_safetensors: bool
-):
+def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
     with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
         internlm2_config_dict: Dict[str, Any] = json.load(f)
 
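The regrouped imports above bring in shard_checkpoint together with the safetensors/PyTorch weight-file name constants from transformers. As background only (not part of this diff), here is a minimal sketch of how that pair is typically used to write sharded weights; the toy state dict and the output path are made-up placeholders:

    import json
    import os
    from collections import OrderedDict

    import torch
    from safetensors.torch import save_file
    from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, shard_checkpoint

    # Toy stand-in for the converted LLaMA2-style state dict (hypothetical values).
    state_dict = OrderedDict({"model.embed_tokens.weight": torch.zeros(8, 4)})

    output_dir = "output"  # hypothetical path
    os.makedirs(output_dir, exist_ok=True)

    # shard_checkpoint splits the state dict into size-bounded shards plus an optional index.
    shards, index = shard_checkpoint(state_dict, max_shard_size="10GB", weights_name=SAFE_WEIGHTS_NAME)
    for shard_file, shard in shards.items():
        # The "format": "pt" metadata is the usual convention for PyTorch safetensors files.
        save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})

    if index is not None:
        # Multi-shard checkpoints also get an index mapping each weight to its shard file.
        with open(os.path.join(output_dir, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2)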
@@ -50,8 +46,10 @@ def save_weight(
             q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
             kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
             llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
-            llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[q_size:q_size+kv_size, ...]
-            llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size+kv_size:, ...]
+            llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[
+                q_size : q_size + kv_size, ...
+            ]
+            llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size + kv_size :, ...]
         elif "wo" in key:
             llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
         elif "attention_norm" in key:
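The slicing reformatted in the hunk above splits a fused wqkv weight into separate q/k/v projections purely by row ranges. A standalone sketch with made-up sizes, for reference only; per the warning at the top of the file the converted model does not yet infer correctly, so this straight [q | k | v] split may not match InternLM2's real wqkv layout:

    import torch

    # Made-up sizes: 8 query heads, 2 key/value heads (grouped-query attention), head_dim 4.
    num_q_heads, num_kv_heads, head_dim, hidden_size = 8, 2, 4, 32
    wqkv = torch.randn((num_q_heads + 2 * num_kv_heads) * head_dim, hidden_size)

    # Same arithmetic as the script: rows are assumed to be laid out as [q | k | v].
    q_size = wqkv.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
    kv_size = wqkv.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
    q_proj = wqkv[:q_size, ...]
    k_proj = wqkv[q_size : q_size + kv_size, ...]
    v_proj = wqkv[q_size + kv_size :, ...]

    assert q_proj.size(0) == num_q_heads * head_dim
    assert k_proj.size(0) == v_proj.size(0) == num_kv_heads * head_dim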
@@ -85,10 +83,7 @@ def save_weight(
     print("Model weights saved in {}".format(output_dir))
 
 
-def save_config(
-    input_dir: str,
-    output_dir: str
-):
+def save_config(input_dir: str, output_dir: str):
     with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
         llama2_config_dict: Dict[str, Any] = json.load(f)
 
@@ -103,12 +98,7 @@ def save_config(
     print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
 
 
-def llamafy_internlm2(
-    input_dir: str,
-    output_dir: str,
-    shard_size: str,
-    save_safetensors: Optional[bool] = False
-):
+def llamafy_internlm2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
     try:
         os.makedirs(output_dir, exist_ok=False)
     except Exception as e:
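Outside the hunks shown above, the fire import presumably wires llamafy_internlm2 to the command line given in the usage comment; a minimal, hypothetical sketch of that entry point (the actual wiring is not part of this diff, and the function body is elided):

    import fire

    def llamafy_internlm2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool = False):
        ...  # conversion steps as in the diff above

    if __name__ == "__main__":
        # fire maps keyword arguments to CLI flags, matching the usage line at the top of the file:
        # python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB
        fire.Fire(llamafy_internlm2)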