Merge pull request #2746 from stephen-nju/main

fix deepspeed ppo RuntimeError

Former-commit-id: 656c653f0c628f9494b4d7ae12e60c8eeec1ea7a
This commit is contained in:
hoshi-hiyouga
2024-03-09 01:37:00 +08:00
committed by GitHub
2 changed files with 5 additions and 3 deletions

View File

@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead
from ..hparams import ModelArguments
from ..hparams import ModelArguments,FinetuningArguments
logger = get_logger(__name__)
@@ -265,6 +265,7 @@ def patch_config(
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",
init_kwargs: Dict[str, Any],
is_trainable: bool,
) -> None:
@@ -289,7 +290,8 @@ def patch_config(
if not is_deepspeed_zero3_enabled():
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
if "device_map" not in init_kwargs: # quant models cannot use auto device map
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
if finetuning_args.stage not in ["ppo"]: #ppo stage should not set device map
init_kwargs["device_map"] = model_args.device_map or {"": get_current_device()}
def patch_model(