Former-commit-id: 819cc1353599e5fa45658bc56dd0dbe4b258b197
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from peft import (
|
||||
PeftModel,
|
||||
TaskType,
|
||||
@@ -12,19 +12,22 @@ from peft.utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.save_and_load import load_trainable_params
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def init_adapter(
|
||||
model: PreTrainedModel,
|
||||
model_args: ModelArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
model: "PreTrainedModel",
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: bool,
|
||||
is_mergeable: bool
|
||||
) -> PreTrainedModel:
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Initializes the adapters.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Literal, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Tuple
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
@@ -16,11 +16,13 @@ from transformers.tokenization_utils import PreTrainedTokenizerBase
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params
|
||||
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
|
||||
from llmtuner.extras.save_and_load import load_valuehead_params
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core.adapter import init_adapter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -33,8 +35,8 @@ require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
model_args: ModelArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: Optional[bool] = False,
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
|
||||
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
|
||||
@@ -141,6 +143,9 @@ def load_model_and_tokenizer(
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
|
||||
|
||||
print_trainable_params(model)
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param
|
||||
))
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
@@ -19,20 +19,39 @@ from llmtuner.hparams import (
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
return parser.parse_args_into_dataclasses()
|
||||
|
||||
|
||||
def parse_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
|
||||
parser = HfArgumentParser((
|
||||
GeneralArguments, ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments
|
||||
))
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def parse_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
))
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))
|
||||
|
||||
if args is not None:
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()
|
||||
general_args, model_args, data_args, training_args, finetuning_args = parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
@@ -73,13 +92,22 @@ def get_train_args(
|
||||
if training_args.do_train and (not training_args.fp16):
|
||||
logger.warning("We recommend enable fp16 mixed precision training.")
|
||||
|
||||
if data_args.prompt_template == "default":
|
||||
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
|
||||
|
||||
if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
|
||||
if (
|
||||
training_args.local_rank != -1
|
||||
and training_args.ddp_find_unused_parameters is None
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
):
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
|
||||
training_args.ddp_find_unused_parameters = False
|
||||
|
||||
if data_args.max_samples is not None and data_args.streaming:
|
||||
logger.warning("`max_samples` is incompatible with `streaming`. Disabling streaming mode.")
|
||||
data_args.streaming = False
|
||||
|
||||
if data_args.dev_ratio > 1e-6 and data_args.streaming:
|
||||
logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
|
||||
data_args.dev_ratio = 0
|
||||
|
||||
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
@@ -106,17 +134,7 @@ def get_train_args(
|
||||
def get_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))
|
||||
|
||||
if args is not None:
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
|
||||
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
|
||||
|
||||
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
|
||||
"Quantization is only compatible with the LoRA method."
|
||||
@@ -128,7 +146,4 @@ def get_infer_args(
|
||||
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
|
||||
"Quantized model only accepts a single checkpoint."
|
||||
|
||||
if data_args.prompt_template == "default":
|
||||
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Dict, Optional
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
from transformers import Seq2SeqTrainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
|
||||
from transformers.modeling_utils import PreTrainedModel, unwrap_model
|
||||
from peft import PeftModel
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -21,7 +24,7 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
|
||||
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
self._remove_log()
|
||||
@@ -42,31 +45,35 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
logger.info(f"Saving model checkpoint to {output_dir}")
|
||||
|
||||
model = unwrap_model(self.model)
|
||||
state_dict = state_dict or get_state_dict(model)
|
||||
|
||||
if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
|
||||
backbone_model = getattr(model, "pretrained_model")
|
||||
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
else:
|
||||
backbone_model = model
|
||||
if isinstance(model, PreTrainedModelWrapper):
|
||||
model_params, v_head_params = {}, {}
|
||||
for name in state_dict.keys():
|
||||
if name.startswith("pretrained_model."):
|
||||
model_params[name.replace("pretrained_model.", "")] = state_dict[name]
|
||||
elif name.startswith("v_head."):
|
||||
v_head_params[name.replace("v_head.", "")] = state_dict[name]
|
||||
|
||||
if isinstance(backbone_model, PeftModel): # LoRA tuning
|
||||
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
|
||||
elif isinstance(backbone_model, PreTrainedModel): # freeze/full tuning
|
||||
backbone_model.config.use_cache = True
|
||||
backbone_model.save_pretrained(
|
||||
output_dir,
|
||||
state_dict=get_state_dict(backbone_model, trainable_only=(self.finetuning_args.finetuning_type != "full")),
|
||||
safe_serialization=self.args.save_safetensors
|
||||
)
|
||||
backbone_model.config.use_cache = False
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
torch.save(v_head_params, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
state_dict = model_params
|
||||
model = model.pretrained_model
|
||||
|
||||
if isinstance(model, (PeftModel, PreTrainedModel)):
|
||||
model.config.use_cache = True
|
||||
model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
|
||||
model.config.use_cache = False
|
||||
else:
|
||||
logger.warning("No model to save.")
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
|
||||
f.write(self.args.to_json_string() + "\n")
|
||||
|
||||
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
|
||||
|
||||
def _load_best_model(self):
|
||||
@@ -76,16 +83,15 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
|
||||
"""
|
||||
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
|
||||
|
||||
model = unwrap_model(self.model)
|
||||
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
|
||||
|
||||
if isinstance(backbone_model, PeftModel):
|
||||
backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
|
||||
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "reward_head_weight"),
|
||||
"summary.bias": getattr(model, "reward_head_bias")
|
||||
})
|
||||
if isinstance(model, PreTrainedModelWrapper):
|
||||
model.v_head.load_state_dict(torch.load(
|
||||
os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
|
||||
))
|
||||
model = model.pretrained_model
|
||||
|
||||
if isinstance(model, PeftModel):
|
||||
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
|
||||
else: # freeze/full-tuning
|
||||
load_trainable_params(backbone_model, self.state.best_model_checkpoint)
|
||||
load_trainable_params(model, self.state.best_model_checkpoint)
|
||||
|
||||
@@ -2,21 +2,25 @@ import os
|
||||
import math
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl
|
||||
from transformers import TrainerState, TrainerControl
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from trl import PPOTrainer
|
||||
from trl.core import LengthSampler
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, get_logits_processor
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -27,9 +31,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: List[LogCallback],
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: List["LogCallback"],
|
||||
**kwargs
|
||||
):
|
||||
PPOTrainer.__init__(self, **kwargs)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import torch
|
||||
from typing import Dict, List, Literal, Optional, Tuple
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
|
||||
|
||||
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
||||
if target == "reward": # save default head temporarily
|
||||
valuehead_state_dict = model.v_head.state_dict()
|
||||
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
|
||||
@@ -19,10 +21,10 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
|
||||
|
||||
|
||||
def cast_layernorm_dtype(
|
||||
model: AutoModelForCausalLMWithValueHead,
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
layer_norm_names: List[str] = LAYERNORM_NAMES,
|
||||
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
|
||||
) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
|
||||
|
||||
layer_norm_state_dict = {}
|
||||
|
||||
|
||||
@@ -2,26 +2,30 @@
|
||||
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
from trl import PPOConfig
|
||||
from torch.optim import AdamW
|
||||
from typing import Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from transformers.optimization import get_scheduler
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
def run_ppo(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
|
||||
|
||||
@@ -1,24 +1,27 @@
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
|
||||
|
||||
import math
|
||||
from typing import Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
def run_pt(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
|
||||
|
||||
@@ -15,5 +15,8 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
|
||||
features = [
|
||||
{"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
|
||||
for key in ("accept_ids", "reject_ids") for feature in features
|
||||
]
|
||||
return super().__call__(features)
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from transformers.trainer import PredictionOutput
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.trainer import PredictionOutput
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -23,7 +25,7 @@ class PairwisePeftTrainer(PeftTrainer):
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: PreTrainedModel,
|
||||
model: "PreTrainedModel",
|
||||
inputs: Dict[str, torch.Tensor],
|
||||
return_outputs: Optional[bool] = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
@@ -46,7 +48,7 @@ class PairwisePeftTrainer(PeftTrainer):
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
predict_results: PredictionOutput
|
||||
predict_results: "PredictionOutput"
|
||||
) -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
||||
|
||||
@@ -2,25 +2,27 @@
|
||||
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
|
||||
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||
|
||||
from typing import Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.rm.metric import compute_accuracy
|
||||
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
|
||||
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
def run_rm(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Sequence, Tuple, Union
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
|
||||
|
||||
import jieba
|
||||
from rouge_chinese import Rouge
|
||||
@@ -9,6 +8,9 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComputeMetrics:
|
||||
@@ -16,7 +18,7 @@ class ComputeMetrics:
|
||||
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizer
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
|
||||
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
||||
r"""
|
||||
|
||||
@@ -3,13 +3,15 @@ import json
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from transformers.trainer import PredictionOutput
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.trainer import PredictionOutput
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -81,7 +83,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
predict_results: PredictionOutput
|
||||
predict_results: "PredictionOutput"
|
||||
) -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
||||
|
||||
@@ -1,25 +1,28 @@
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
|
||||
|
||||
from typing import Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.sft.metric import ComputeMetrics
|
||||
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
def run_sft(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
|
||||
Reference in New Issue
Block a user