[v1] model loader (#9613)

This commit is contained in:
Yaowei Zheng
2025-12-14 11:50:52 +08:00
committed by GitHub
parent fdd24276ed
commit aeda079014
27 changed files with 449 additions and 305 deletions

View File

@@ -13,11 +13,11 @@
# limitations under the License.
from ..accelerator.interface import DistributedInterface, DistributedStrategy
from ..accelerator.interface import DistributedInterface
from ..config.arg_parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_worker import ModelWorker
from ..core.model_loader import ModelLoader
class SFTTrainer(BaseTrainer):
@@ -26,8 +26,13 @@ class SFTTrainer(BaseTrainer):
def run_sft(user_args):
    """Run supervised fine-tuning from user-supplied arguments.

    Parses ``user_args`` with ``get_args``, initializes ``DistributedInterface``
    from ``training_args.dist_config``, builds a ``DataEngine`` and a
    ``ModelLoader``, then constructs an ``SFTTrainer`` and calls ``fit()``.

    Args:
        user_args: Raw arguments, forwarded unchanged to ``get_args``.
    """
    model_args, data_args, training_args, _ = get_args(user_args)

    # Initialize the distributed interface once, from the parsed training
    # config; the stale ``DistributedStrategy()`` initialization is dropped.
    DistributedInterface(training_args.dist_config)

    data_engine = DataEngine(data_args)
    model_loader = ModelLoader(model_args)

    # The trainer receives the loader's model and processor via keywords; the
    # superseded ``ModelWorker`` / positional ``SFTTrainer`` construction is
    # removed rather than built and immediately overwritten.
    trainer = SFTTrainer(
        args=training_args,
        model=model_loader.model,
        processor=model_loader.processor,
        dataset=data_engine,
    )
    trainer.fit()