modity code structure
Former-commit-id: 0682ed357210897e0b67c4a6eb31a94b3eb929f1
This commit is contained in:
2
src/llmtuner/dsets/__init__.py
Normal file
2
src/llmtuner/dsets/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from llmtuner.dsets.loader import get_dataset
|
||||
from llmtuner.dsets.preprocess import preprocess_dataset
|
||||
63
src/llmtuner/dsets/callbacks.py
Normal file
63
src/llmtuner/dsets/callbacks.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from datetime import timedelta
|
||||
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments
|
||||
)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
|
||||
def __init__(self, runner=None):
|
||||
self.runner = runner
|
||||
self.start_time = time.time()
|
||||
self.tracker = {}
|
||||
|
||||
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of a training step. If using gradient accumulation, one training step
|
||||
might take several inputs.
|
||||
"""
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
"""
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
if "loss" not in state.log_history[-1]:
|
||||
return
|
||||
cur_time = time.time()
|
||||
cur_steps = state.log_history[-1].get("step")
|
||||
elapsed_time = cur_time - self.start_time
|
||||
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
||||
remaining_steps = state.max_steps - cur_steps
|
||||
remaining_time = remaining_steps * avg_time_per_step
|
||||
self.tracker = {
|
||||
"current_steps": cur_steps,
|
||||
"total_steps": state.max_steps,
|
||||
"loss": state.log_history[-1].get("loss", None),
|
||||
"reward": state.log_history[-1].get("reward", None),
|
||||
"learning_rate": state.log_history[-1].get("learning_rate", None),
|
||||
"epoch": state.log_history[-1].get("epoch", None),
|
||||
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
|
||||
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
|
||||
"remaining_time": str(timedelta(seconds=int(remaining_time)))
|
||||
}
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self.tracker) + "\n")
|
||||
106
src/llmtuner/dsets/loader.py
Normal file
106
src/llmtuner/dsets/loader.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
import hashlib
|
||||
from typing import List
|
||||
|
||||
from datasets import Dataset, concatenate_datasets, load_dataset
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.hparams import ModelArguments, DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_dataset(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments
|
||||
) -> Dataset:
|
||||
|
||||
def checksum(file_path, hash):
|
||||
with open(file_path, "rb") as datafile:
|
||||
binary_data = datafile.read()
|
||||
sha1 = hashlib.sha1(binary_data).hexdigest()
|
||||
if sha1 != hash:
|
||||
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
|
||||
|
||||
ext2type = {
|
||||
"csv": "csv",
|
||||
"json": "json",
|
||||
"jsonl": "json",
|
||||
"txt": "text"
|
||||
}
|
||||
|
||||
max_samples = data_args.max_samples
|
||||
all_datasets: List[Dataset] = [] # support multiple datasets
|
||||
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
|
||||
if dataset_attr.load_from == "hf_hub":
|
||||
data_path = dataset_attr.dataset_name
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "script":
|
||||
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "file":
|
||||
data_path = None
|
||||
data_files: List[str] = []
|
||||
|
||||
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
|
||||
|
||||
if data_path is None:
|
||||
data_path = ext2type.get(data_files[0].split(".")[-1], None)
|
||||
else:
|
||||
assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
|
||||
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
|
||||
data_path = ext2type.get(data_files[0].split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File not found.")
|
||||
|
||||
assert data_path, "File extension must be txt, csv, json or jsonl."
|
||||
|
||||
if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
|
||||
checksum(data_files[0], dataset_attr.dataset_sha1)
|
||||
else:
|
||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
raw_datasets = load_dataset(
|
||||
data_path,
|
||||
data_files=data_files,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None
|
||||
)
|
||||
dataset = raw_datasets[data_args.split]
|
||||
|
||||
if max_samples is not None:
|
||||
max_samples_temp = min(len(dataset), max_samples)
|
||||
dataset = dataset.select(range(max_samples_temp))
|
||||
|
||||
dummy_data = [None] * len(dataset)
|
||||
prefix_data = [dataset_attr.source_prefix] * len(dataset)
|
||||
for column_name, target_name in [
|
||||
("prompt_column", "prompt"),
|
||||
("query_column", "query"),
|
||||
("response_column", "response"),
|
||||
("history_column", "history")
|
||||
]: # every dataset will have 4 columns same as each other
|
||||
if getattr(dataset_attr, column_name) != target_name:
|
||||
if getattr(dataset_attr, column_name):
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
|
||||
else: # None or empty string
|
||||
dataset = dataset.add_column(target_name, dummy_data)
|
||||
dataset = dataset.add_column("prefix", prefix_data)
|
||||
all_datasets.append(dataset)
|
||||
|
||||
if len(data_args.dataset_list) == 1:
|
||||
all_datasets = all_datasets[0]
|
||||
else:
|
||||
all_datasets = concatenate_datasets(all_datasets)
|
||||
|
||||
return all_datasets
|
||||
172
src/llmtuner/dsets/preprocess.py
Normal file
172
src/llmtuner/dsets/preprocess.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from typing import Literal
|
||||
from itertools import chain
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.template import Template
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
def preprocess_dataset(
|
||||
dataset: Dataset,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||
) -> Dataset:
|
||||
|
||||
column_names = list(dataset.column_names)
|
||||
prompt_template = Template(data_args.prompt_template)
|
||||
|
||||
# support question with a single answer or multiple answers
|
||||
def get_dialog(examples):
|
||||
for i in range(len(examples["prompt"])):
|
||||
if examples["prompt"][i] and examples["response"][i]:
|
||||
query, answer = examples["prompt"][i], examples["response"][i]
|
||||
query = query + "\n" + examples["query"][i] if examples["query"][i] else query
|
||||
prefix = examples["prefix"][i] if examples["prefix"][i] else ""
|
||||
dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
|
||||
yield dialog
|
||||
|
||||
def preprocess_pretrain_dataset(examples):
|
||||
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
|
||||
text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
|
||||
concatenated_ids = list(chain(*text_ids))
|
||||
total_length = len(concatenated_ids)
|
||||
block_size = data_args.max_source_length - 1
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of max_source_length
|
||||
result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
|
||||
for i in range(0, total_length, block_size)]
|
||||
return {
|
||||
"input_ids": result,
|
||||
"labels": result.copy()
|
||||
}
|
||||
|
||||
def preprocess_supervised_dataset(examples):
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for input with history, we build multiple input-label pairs just like:
|
||||
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
max_length = data_args.max_source_length + data_args.max_target_length
|
||||
|
||||
for dialog in get_dialog(examples):
|
||||
input_ids, labels = [], []
|
||||
|
||||
for i in range(len(dialog) // 2):
|
||||
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
|
||||
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(target_ids) > data_args.max_target_length - 1: # eos token
|
||||
target_ids = target_ids[:data_args.max_target_length - 1]
|
||||
|
||||
if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
|
||||
break
|
||||
|
||||
input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
|
||||
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["labels"].append(labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_unsupervised_dataset(examples):
|
||||
# build inputs with format `<bos> X` and labels with format `<bos> Y`
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
||||
target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(target_ids) > data_args.max_target_length:
|
||||
target_ids = target_ids[:data_args.max_target_length]
|
||||
|
||||
model_inputs["input_ids"].append(source_ids)
|
||||
model_inputs["labels"].append(target_ids)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_pairwise_dataset(examples):
|
||||
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
|
||||
model_inputs = {"accept_ids": [], "reject_ids": []}
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
||||
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
|
||||
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(accept_ids) > data_args.max_target_length - 1: # eos token
|
||||
accept_ids = accept_ids[:data_args.max_target_length - 1]
|
||||
if len(reject_ids) > data_args.max_target_length - 1: # eos token
|
||||
reject_ids = reject_ids[:data_args.max_target_length - 1]
|
||||
|
||||
accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
|
||||
reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["accept_ids"].append(accept_ids)
|
||||
model_inputs["reject_ids"].append(reject_ids)
|
||||
return model_inputs
|
||||
|
||||
def print_supervised_dataset_example(example):
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(
|
||||
tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
|
||||
skip_special_tokens=False)
|
||||
))
|
||||
|
||||
def print_pairwise_dataset_example(example):
|
||||
print("accept_ids:\n{}".format(example["accept_ids"]))
|
||||
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
|
||||
print("reject_ids:\n{}".format(example["reject_ids"]))
|
||||
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
|
||||
|
||||
def print_unsupervised_dataset_example(example):
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
||||
if stage == "pt":
|
||||
preprocess_function = preprocess_pretrain_dataset
|
||||
elif stage == "sft":
|
||||
preprocess_function = preprocess_unsupervised_dataset \
|
||||
if training_args.predict_with_generate else preprocess_supervised_dataset
|
||||
elif stage == "rm":
|
||||
preprocess_function = preprocess_pairwise_dataset
|
||||
elif stage == "ppo":
|
||||
preprocess_function = preprocess_unsupervised_dataset
|
||||
|
||||
with training_args.main_process_first(desc="dataset map pre-processing"):
|
||||
dataset = dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset"
|
||||
)
|
||||
|
||||
if stage == "pt":
|
||||
print_unsupervised_dataset_example(dataset[0])
|
||||
elif stage == "sft":
|
||||
print_supervised_dataset_example(dataset[0])
|
||||
elif stage == "rm":
|
||||
print_pairwise_dataset_example(dataset[0])
|
||||
elif stage == "ppo":
|
||||
print_unsupervised_dataset_example(dataset[0])
|
||||
|
||||
return dataset
|
||||
Reference in New Issue
Block a user