mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-03 08:53:38 +00:00
[v1] add batch generator (#9744)
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
|
||||
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from ..config.arg_parser import get_args
|
||||
from ..config import InputArgument, get_args
|
||||
from ..core.base_trainer import BaseTrainer
|
||||
from ..core.data_engine import DataEngine
|
||||
from ..core.model_engine import ModelEngine
|
||||
@@ -24,15 +24,15 @@ class SFTTrainer(BaseTrainer):
|
||||
pass
|
||||
|
||||
|
||||
def run_sft(user_args):
|
||||
model_args, data_args, training_args, _ = get_args(user_args)
|
||||
def run_sft(args: InputArgument = None):
|
||||
model_args, data_args, training_args, _ = get_args(args)
|
||||
DistributedInterface(training_args.dist_config)
|
||||
data_engine = DataEngine(data_args)
|
||||
model_engine = ModelEngine(model_args)
|
||||
trainer = SFTTrainer(
|
||||
args=training_args,
|
||||
model=model_engine.model,
|
||||
processor=model_engine.processor,
|
||||
renderer=model_engine.renderer,
|
||||
dataset=data_engine,
|
||||
)
|
||||
trainer.fit()
|
||||
|
||||
Reference in New Issue
Block a user