[misc] upgrade format to py39 (#7256)

@@ -18,7 +18,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 
 import fire
 import torch
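
The dropped `Dict` import reflects the point of this commit: since Python 3.9 (PEP 585) the built-in container types are themselves generic, so the `typing` aliases are no longer needed. A minimal sketch of the pattern, using an illustrative variable name that is not taken from the script:

    from collections import OrderedDict

    import torch

    # Python 3.8 style, now removed:  tensor_map: Dict[str, "torch.Tensor"] = OrderedDict()
    # Python 3.9+ style (PEP 585): subscript the built-in dict directly
    tensor_map: dict[str, torch.Tensor] = OrderedDict()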

@@ -44,11 +44,11 @@ def block_expansion(
     shard_size: str = "5GB",
     save_safetensors: bool = True,
 ):
-    r"""
-    Performs block expansion for LLaMA, Mistral, Qwen2 or Yi models.
+    r"""Perform block expansion for LLaMA, Mistral, Qwen2 or Yi models.
+
     Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
     """
-    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+    config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
     num_layers = getattr(config, "num_hidden_layers")
     if num_layers % num_expand != 0:
         raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
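
To make the divisibility requirement concrete (assuming the stock meta-llama/Llama-2-7b-hf configuration, which has 32 hidden layers; the numbers simply replay the usage line above):

    num_layers = 32                    # Llama-2-7b-hf decoder layers
    num_expand = 8                     # from the usage example
    assert num_layers % num_expand == 0
    split = num_layers // num_expand   # 4 original layers between consecutive expanded blocks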

@@ -70,7 +70,7 @@ def block_expansion(
     split = num_layers // num_expand
     layer_cnt = 0
     state_dict = model.state_dict()
-    output_state_dict: Dict[str, "torch.Tensor"] = OrderedDict()
+    output_state_dict: dict[str, torch.Tensor] = OrderedDict()
     for i in range(num_layers):
         for key, value in state_dict.items():
             if f".{i:d}." in key:
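
The hunk stops mid-loop, so here is a hedged sketch of the kind of key remapping this loop performs in the LLaMA Pro scheme: each original layer is copied under a shifted index, and an extra copy is interleaved after every `split` layers. The actual recipe also zero-initializes the new block's output projections so it starts as an identity mapping, and handles non-layer weights such as embeddings separately; both details are omitted, and the function below is an illustration rather than the script's own code:

    from collections import OrderedDict

    import torch


    def expand_blocks(
        state_dict: dict[str, torch.Tensor], num_layers: int, split: int
    ) -> dict[str, torch.Tensor]:
        # Illustrative only: interleave one duplicated block after every `split` layers.
        output_state_dict: dict[str, torch.Tensor] = OrderedDict()
        layer_cnt = 0
        for i in range(num_layers):
            # copy the original layer under its new (shifted) index
            for key, value in state_dict.items():
                if f".{i:d}." in key:
                    output_state_dict[key.replace(f".{i:d}.", f".{layer_cnt:d}.")] = value
            layer_cnt += 1
            if (i + 1) % split == 0:
                # the expanded block starts as a copy of the layer just processed
                for key, value in state_dict.items():
                    if f".{i:d}." in key:
                        output_state_dict[key.replace(f".{i:d}.", f".{layer_cnt:d}.")] = torch.clone(value)
                layer_cnt += 1
        return output_state_dict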