mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-02 20:43:38 +00:00
[v1] model loader (#9613)
This commit is contained in:
@@ -13,11 +13,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ..accelerator.interface import DistributedInterface, DistributedStrategy
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from ..config.arg_parser import get_args
|
||||
from ..core.base_trainer import BaseTrainer
|
||||
from ..core.data_engine import DataEngine
|
||||
from ..core.model_worker import ModelWorker
|
||||
from ..core.model_loader import ModelLoader
|
||||
|
||||
|
||||
class SFTTrainer(BaseTrainer):
|
||||
@@ -26,8 +26,13 @@ class SFTTrainer(BaseTrainer):
|
||||
|
||||
def run_sft(user_args):
    """Run supervised fine-tuning end to end from ``user_args``.

    Parses the arguments, initialises the distributed interface, builds the
    data engine and the model loader, then constructs an ``SFTTrainer`` from
    them and calls its ``fit()`` method.

    Args:
        user_args: Raw arguments, forwarded unchanged to ``get_args``.
    """
    # get_args yields four values; the last one is not used here.
    model_args, data_args, training_args, _ = get_args(user_args)

    # Constructed exactly once, from the training configuration; the instance
    # itself is not kept.
    # NOTE(review): presumably this sets up process-wide distributed state
    # that the objects built below rely on -- confirm against the class.
    DistributedInterface(training_args.dist_config)

    data_engine = DataEngine(data_args)
    model_loader = ModelLoader(model_args)

    # Keyword arguments make the wiring explicit: the trainer receives the
    # loader's model and processor rather than the loader object itself.
    trainer = SFTTrainer(
        args=training_args,
        model=model_loader.model,
        processor=model_loader.processor,
        dataset=data_engine,
    )
    trainer.fit()
|
||||
|
||||
Reference in New Issue
Block a user