simplify code
Former-commit-id: d3731754ab7c28ae81f60784e0e4213f279d93fe
This commit is contained in:
@@ -4,7 +4,7 @@ import math
|
||||
from typing import Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
@@ -28,16 +28,6 @@ def run_pt(
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
# Split the dataset
|
||||
if training_args.do_train:
|
||||
if data_args.dev_ratio > 1e-6:
|
||||
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
|
||||
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
else:
|
||||
trainer_kwargs = {"train_dataset": dataset}
|
||||
else: # do_eval or do_predict
|
||||
trainer_kwargs = {"eval_dataset": dataset}
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = PeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
@@ -46,7 +36,7 @@ def run_pt(
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**trainer_kwargs
|
||||
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
|
||||
)
|
||||
|
||||
# Training
|
||||
|
||||
Reference in New Issue
Block a user