fix llamafy scripts

Former-commit-id: 99ff69c36767d4397a4a61e89317ec8c0c295c1e
This commit is contained in:
hiyouga
2024-01-18 00:37:37 +08:00
parent 344412e66e
commit 97b52c7fdf
3 changed files with 5 additions and 4 deletions

View File

@@ -31,7 +31,7 @@ def save_weight(
save_safetensors: bool
):
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for filepath in os.listdir(input_dir):
for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
baichuan2_state_dict.update(shard_weight)