simplify code

Former-commit-id: d3731754ab7c28ae81f60784e0e4213f279d93fe
This commit is contained in:
hiyouga
2023-07-20 15:08:57 +08:00
parent 7a43ff3d89
commit 5ba0b80e5c
18 changed files with 52 additions and 136 deletions

View File

@@ -4,7 +4,7 @@ import math
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
@@ -28,16 +28,6 @@ def run_pt(
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
)
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer
trainer = PeftTrainer(
finetuning_args=finetuning_args,
@@ -46,7 +36,7 @@ def run_pt(
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**trainer_kwargs
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
)
# Training

View File

@@ -5,7 +5,7 @@
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
@@ -29,16 +29,6 @@ def run_rm(
training_args.remove_unused_columns = False # important for pairwise dataset
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer
trainer = PairwisePeftTrainer(
finetuning_args=finetuning_args,
@@ -48,7 +38,7 @@ def run_rm(
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=compute_accuracy,
**trainer_kwargs
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
)
# Training

View File

@@ -3,7 +3,7 @@
from typing import Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
@@ -35,16 +35,6 @@ def run_sft(
training_args.generation_num_beams = data_args.eval_num_beams if \
data_args.eval_num_beams is not None else training_args.generation_num_beams
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer
trainer = Seq2SeqPeftTrainer(
finetuning_args=finetuning_args,
@@ -54,7 +44,7 @@ def run_sft(
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**trainer_kwargs
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
)
# Keyword arguments for `model.generate`