fix save function

Former-commit-id: 1d6beb0c8490a7531ffdf7a2819410597b200d12
This commit is contained in:
hiyouga
2023-07-21 14:09:07 +08:00
parent 49c90044ce
commit a1468139a5
2 changed files with 4 additions and 4 deletions

View File

@@ -1,6 +1,6 @@
import os
import torch
from typing import Dict
from typing import Dict, Optional
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.modeling_utils import load_sharded_checkpoint
@@ -12,12 +12,12 @@ from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
def get_state_dict(model: torch.nn.Module, trainable_only: Optional[bool] = True) -> Dict[str, torch.Tensor]:
state_dict = model.state_dict()
filtered_state_dict = {}
for k, v in model.named_parameters():
if v.requires_grad:
if (not trainable_only) or v.requires_grad:
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
return filtered_state_dict