mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-19 23:33:09 +00:00
[v1] add sft (#9752)
This commit is contained in:
@@ -18,21 +18,35 @@ from ..config import InputArgument, get_args
|
||||
from ..core.base_trainer import BaseTrainer
|
||||
from ..core.data_engine import DataEngine
|
||||
from ..core.model_engine import ModelEngine
|
||||
from ..utils.types import BatchInput, Tensor
|
||||
|
||||
|
||||
class SFTTrainer(BaseTrainer):
|
||||
pass
|
||||
def compute_loss(self, batch: BatchInput) -> Tensor:
|
||||
shift_loss_weights = batch["loss_weights"].to(self.device, non_blocking=True)[..., 1:]
|
||||
log_probs = self.compute_log_probs(self.model, batch)
|
||||
loss = (-log_probs * shift_loss_weights).sum() / (shift_loss_weights.sum() + 1e-6)
|
||||
return loss
|
||||
|
||||
|
||||
def run_sft(args: InputArgument = None):
|
||||
model_args, data_args, training_args, _ = get_args(args)
|
||||
DistributedInterface(training_args.dist_config)
|
||||
data_engine = DataEngine(data_args)
|
||||
train_dataset = DataEngine(data_args.train_dataset)
|
||||
model_engine = ModelEngine(model_args)
|
||||
trainer = SFTTrainer(
|
||||
args=training_args,
|
||||
model=model_engine.model,
|
||||
renderer=model_engine.renderer,
|
||||
dataset=data_engine,
|
||||
train_dataset=train_dataset,
|
||||
)
|
||||
trainer.fit()
|
||||
trainer.save_model()
|
||||
DistributedInterface().destroy()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m llamafactory.v1.trainers.sft_trainer --model Qwen/Qwen3-0.6B --train_dataset data/v1_sft_demo.yaml
|
||||
"""
|
||||
run_sft()
|
||||
|
||||
Reference in New Issue
Block a user