mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 20:43:38 +00:00
[v1] add cli sampler (#9721)
This commit is contained in:
@@ -17,7 +17,7 @@ from ..accelerator.interface import DistributedInterface
|
||||
from ..config.arg_parser import get_args
|
||||
from ..core.base_trainer import BaseTrainer
|
||||
from ..core.data_engine import DataEngine
|
||||
from ..core.model_loader import ModelLoader
|
||||
from ..core.model_engine import ModelEngine
|
||||
|
||||
|
||||
class SFTTrainer(BaseTrainer):
|
||||
@@ -28,11 +28,11 @@ def run_sft(user_args):
|
||||
model_args, data_args, training_args, _ = get_args(user_args)
|
||||
DistributedInterface(training_args.dist_config)
|
||||
data_engine = DataEngine(data_args)
|
||||
model_loader = ModelLoader(model_args)
|
||||
model_engine = ModelEngine(model_args)
|
||||
trainer = SFTTrainer(
|
||||
args=training_args,
|
||||
model=model_loader.model,
|
||||
processor=model_loader.processor,
|
||||
model=model_engine.model,
|
||||
processor=model_engine.processor,
|
||||
dataset=data_engine,
|
||||
)
|
||||
trainer.fit()
|
||||
|
||||
Reference in New Issue
Block a user