mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-03 21:03:10 +00:00
[v1] add accelerator (#9607)
This commit is contained in:
@@ -13,22 +13,21 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ..config.parser import get_args
|
||||
from ..accelerator.interface import DistributedInterface, DistributedStrategy
|
||||
from ..config.arg_parser import get_args
|
||||
from ..core.base_trainer import BaseTrainer
|
||||
from ..core.data_engine import DataEngine
|
||||
from ..core.model_engine import ModelEngine
|
||||
from ..core.model_worker import ModelWorker
|
||||
|
||||
|
||||
class SFTTrainer(BaseTrainer):
|
||||
pass
|
||||
|
||||
|
||||
def run_sft():
|
||||
model_args, data_args, training_args, _ = get_args()
|
||||
model_engine = ModelEngine(model_args)
|
||||
def run_sft(user_args):
|
||||
model_args, data_args, training_args, _ = get_args(user_args)
|
||||
DistributedInterface(DistributedStrategy())
|
||||
data_engine = DataEngine(data_args)
|
||||
model = model_engine.get_model()
|
||||
processor = model_engine.get_processor()
|
||||
data_loader = data_engine.get_data_loader(processor)
|
||||
trainer = SFTTrainer(training_args, model, processor, data_loader)
|
||||
model_worker = ModelWorker(model_args)
|
||||
trainer = SFTTrainer(training_args, model_worker, data_engine)
|
||||
trainer.fit()
|
||||
|
||||
Reference in New Issue
Block a user