[data] fix loader (#7207)

* fix dataloader

* add test case

* fix type

* fix ci

* fix ci

* fix ci

* disable overwrite cache in ci

Former-commit-id: e84af0e140b1aafd1a6d6fe185a8e41c8fc5f831
hoshi-hiyouga committed 2025-03-07 17:20:46 +08:00 (committed by GitHub)
parent 82a2bac866
commit 16419b2834
16 changed files with 161 additions and 92 deletions


@@ -26,10 +26,11 @@ from ..model import load_model, load_tokenizer
 if TYPE_CHECKING:
     from datasets import Dataset
     from peft import LoraModel
     from transformers import PreTrainedModel
+    from ..data.data_utils import DatasetModule


 def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []) -> None:
     state_dict_a = model_a.state_dict()
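For context on the import's placement: names under `if TYPE_CHECKING:` are imported only by static type checkers, never at runtime, which is why the functions in this file spell their annotations as strings. A minimal, repo-independent sketch of the pattern:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; skipped at runtime, so this
    # neither triggers import cycles nor pulls in heavy dependencies.
    from datasets import Dataset


def load_split() -> "Dataset":  # string annotation resolves lazily
    ...
```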
@@ -101,12 +102,12 @@ def load_reference_model(
     return model


-def load_train_dataset(**kwargs) -> "Dataset":
+def load_dataset_module(**kwargs) -> "DatasetModule":
     model_args, data_args, training_args, _, _ = get_train_args(kwargs)
     tokenizer_module = load_tokenizer(model_args)
     template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
     dataset_module = get_dataset(template, model_args, data_args, training_args, kwargs["stage"], **tokenizer_module)
-    return dataset_module["train_dataset"]
+    return dataset_module


 def patch_valuehead_model() -> None:
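Net effect of the rename: callers now receive the whole `DatasetModule` mapping instead of just its `train_dataset` entry, so tests can reach the eval split as well. A sketch of the caller-side change, assuming the helper's module path and a hypothetical `TRAIN_ARGS` dict standing in for the full argument set the tests build:

```python
# Hypothetical module path and arguments; the real tests pass complete
# model/data/training argument dicts accepted by get_train_args().
from llamafactory.train.test_utils import load_dataset_module

TRAIN_ARGS = {"stage": "sft", "do_train": True}  # incomplete, illustrative only

dataset_module = load_dataset_module(**TRAIN_ARGS)
train_dataset = dataset_module["train_dataset"]    # the old helper's sole return value
eval_dataset = dataset_module.get("eval_dataset")  # now reachable too; may be absent
```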