alter rewards data type

Former-commit-id: 3eb7eb2d37525da50fe401ab7c59532e6e1ef984
hiyouga
2023-06-02 14:19:51 +08:00
parent 896dbfec16
commit e9ab06678f
12 changed files with 40 additions and 50 deletions


@@ -42,8 +42,7 @@ from .other import (
     load_valuehead_params,
     print_trainable_params,
     prepare_model_for_training,
-    IGNORE_INDEX,
-    FINETUNING_ARGS_NAME
+    IGNORE_INDEX
 )

 check_min_version("4.29.1")
@@ -128,7 +127,7 @@ def init_adapter(

 def load_pretrained(
     model_args: ModelArguments,
-    finetuning_args: Optional[FinetuningArguments] = None,
+    finetuning_args: FinetuningArguments,
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
@@ -137,16 +136,9 @@ def load_pretrained(
     Support both training and inference.
     """
-    if finetuning_args is None:  # load the fine-tuning arguments
-        if model_args.checkpoint_dir is None:
-            logger.warning("Checkpoint is not found at evaluation, load the original model.")
-            finetuning_args = FinetuningArguments(finetuning_type="none")
-        elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
-            finetuning_args = FinetuningArguments.load_from_json(
-                os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)
-            )
-        else:
-            raise ValueError("Missing fine-tuning arguments in the provided dictionary.")
+    if (not is_trainable) and model_args.checkpoint_dir is None:
+        logger.warning("Checkpoint is not found at evaluation, load the original model.")
+        finetuning_args = FinetuningArguments(finetuning_type="none")

     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
         "RM and PPO training can only be performed with LoRA method."