PPO: support remote reward model (RM) server for reward scoring

Former-commit-id: 20b0edf16f5b42cb2c4a795674647afb68cb3a4a
This commit is contained in:
hiyouga
2023-12-03 21:38:51 +08:00
parent 29545d0e5e
commit 60aea7521b
5 changed files with 47 additions and 15 deletions

View File

@@ -1,10 +1,24 @@
import json
import torch
from typing import TYPE_CHECKING, Dict, Literal, Optional
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
from llmtuner.extras.packages import is_requests_available
if TYPE_CHECKING:
from transformers import PreTrainedModel
from trl import AutoModelForCausalLMWithValueHead
if is_requests_available():
import requests
def get_rewards_from_server(server_url: str, messages: List[str]) -> torch.Tensor:
    r"""
    Fetch reward scores for a batch of messages from a reward-model API server.

    Args:
        server_url: Full URL of the reward server's scoring endpoint.
        messages: Texts to score; each entry receives one reward value.

    Returns:
        A 1-D float ``torch.Tensor`` of scores, one per message.
        (Fixed: the previous annotation claimed ``List[torch.Tensor]`` but the
        function has always returned a single tensor.)

    Raises:
        requests.RequestException: on connection failure.
        KeyError: if the response JSON lacks a "scores" field.
    """
    headers = {"Content-Type": "application/json"}
    # "model" is a placeholder name; the server scores based on messages only.
    payload = {"model": "model", "messages": messages}
    # NOTE(review): no timeout is set, so this blocks indefinitely on a hung
    # server — consider passing timeout=... and calling raise_for_status().
    response = requests.post(server_url, json=payload, headers=headers)
    # Use requests' built-in JSON decoder instead of json.loads(response.text).
    rewards = response.json()["scores"]
    return torch.Tensor(rewards)
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if target == "reward": # save default head temporarily