[v1] add batch generator (#9744)

This commit is contained in:
Yaowei Zheng
2026-01-10 04:24:09 +08:00
committed by GitHub
parent d7d734d54c
commit b2effbd77c
26 changed files with 604 additions and 850 deletions

View File

@@ -14,7 +14,7 @@
from ..accelerator.interface import DistributedInterface
from ..config.arg_parser import get_args
from ..config import InputArgument, get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
@@ -24,15 +24,15 @@ class SFTTrainer(BaseTrainer):
pass
def run_sft(user_args):
model_args, data_args, training_args, _ = get_args(user_args)
def run_sft(args: InputArgument = None):
model_args, data_args, training_args, _ = get_args(args)
DistributedInterface(training_args.dist_config)
data_engine = DataEngine(data_args)
model_engine = ModelEngine(model_args)
trainer = SFTTrainer(
args=training_args,
model=model_engine.model,
processor=model_engine.processor,
renderer=model_engine.renderer,
dataset=data_engine,
)
trainer.fit()