Former-commit-id: 392bdaf1ea9e5baf6289f2d4415a175dd55a479d
This commit is contained in:
hiyouga
2024-09-11 17:36:42 +08:00
parent 588ea95732
commit 7fd0d2fc2f
4 changed files with 12 additions and 22 deletions

View File

@@ -31,7 +31,7 @@ if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:
r"""
Gets reward scores from the API server.
"""
@@ -66,7 +66,7 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
r"""
Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""
@@ -79,7 +79,7 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
return layer_norm_params
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
r"""
Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
"""