[misc] upgrade format to py39 (#7256)

This commit is contained in:
hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent 5995800bce
commit 264538cb26
113 changed files with 984 additions and 1407 deletions

View File

@@ -18,7 +18,7 @@
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
import fire
import torch
@@ -44,11 +44,11 @@ def block_expansion(
shard_size: str = "5GB",
save_safetensors: bool = True,
):
r"""
Performs block expansion for LLaMA, Mistral, Qwen2 or Yi models.
r"""Perform block expansion for LLaMA, Mistral, Qwen2 or Yi models.
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
"""
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
num_layers = getattr(config, "num_hidden_layers")
if num_layers % num_expand != 0:
raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
@@ -70,7 +70,7 @@ def block_expansion(
split = num_layers // num_expand
layer_cnt = 0
state_dict = model.state_dict()
output_state_dict: Dict[str, "torch.Tensor"] = OrderedDict()
output_state_dict: dict[str, torch.Tensor] = OrderedDict()
for i in range(num_layers):
for key, value in state_dict.items():
if f".{i:d}." in key: