[v1] add accelerator (#9607)

This commit is contained in:
Yaowei Zheng
2025-12-12 19:22:06 +08:00
committed by GitHub
parent 4fd94141a4
commit 203069e11c
36 changed files with 941 additions and 443 deletions

View File

@@ -13,22 +13,21 @@
# limitations under the License.
from ..config.parser import get_args
from ..accelerator.interface import DistributedInterface, DistributedStrategy
from ..config.arg_parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
from ..core.model_worker import ModelWorker
class SFTTrainer(BaseTrainer):
pass
def run_sft():
model_args, data_args, training_args, _ = get_args()
model_engine = ModelEngine(model_args)
def run_sft(user_args):
model_args, data_args, training_args, _ = get_args(user_args)
DistributedInterface(DistributedStrategy())
data_engine = DataEngine(data_args)
model = model_engine.get_model()
processor = model_engine.get_processor()
data_loader = data_engine.get_data_loader(processor)
trainer = SFTTrainer(training_args, model, processor, data_loader)
model_worker = ModelWorker(model_args)
trainer = SFTTrainer(training_args, model_worker, data_engine)
trainer.fit()