mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-24 00:53:07 +00:00
Compare commits
3 Commits
d3bf882e87
...
0779846513
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0779846513 | ||
|
|
45d335c709 | ||
|
|
816480012f |
@@ -14,16 +14,12 @@ dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
||||
|
||||
init_config:
|
||||
name: init_on_meta
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 1
|
||||
global_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: false
|
||||
|
||||
@@ -154,25 +154,24 @@ def vllm_infer(
|
||||
batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
|
||||
|
||||
for j in range(len(batch["input_ids"])):
|
||||
multi_modal_data = {}
|
||||
video_metadata_kwargs = None
|
||||
|
||||
if batch["images"][j] is not None:
|
||||
image = batch["images"][j]
|
||||
multi_modal_data = {
|
||||
"image": template_obj.mm_plugin._regularize_images(
|
||||
image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
|
||||
)["images"]
|
||||
}
|
||||
elif batch["videos"][j] is not None:
|
||||
video_metadata, video_metadata_kwargs = None, None
|
||||
multi_modal_data["image"] = template_obj.mm_plugin._regularize_images(
|
||||
image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
|
||||
)["images"]
|
||||
|
||||
if batch["videos"][j] is not None:
|
||||
video = batch["videos"][j]
|
||||
multi_modal_data = {
|
||||
"video": template_obj.mm_plugin._regularize_videos(
|
||||
video,
|
||||
image_max_pixels=image_max_pixels,
|
||||
image_min_pixels=image_min_pixels,
|
||||
video_fps=video_fps,
|
||||
video_maxlen=video_maxlen,
|
||||
)["videos"]
|
||||
}
|
||||
multi_modal_data["video"] = template_obj.mm_plugin._regularize_videos(
|
||||
video,
|
||||
image_max_pixels=image_max_pixels,
|
||||
image_min_pixels=image_min_pixels,
|
||||
video_fps=video_fps,
|
||||
video_maxlen=video_maxlen,
|
||||
)["videos"]
|
||||
if need_video_kwargs:
|
||||
container = av.open(video[0], "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
@@ -192,18 +191,17 @@ def vllm_infer(
|
||||
video_backend="opencv",
|
||||
)
|
||||
multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
|
||||
elif batch["audios"][j] is not None:
|
||||
|
||||
if batch["audios"][j] is not None:
|
||||
audio = batch["audios"][j]
|
||||
audio_data = template_obj.mm_plugin._regularize_audios(
|
||||
audio,
|
||||
sampling_rate=16000,
|
||||
)
|
||||
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
|
||||
else:
|
||||
multi_modal_data = None
|
||||
multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
|
||||
|
||||
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
|
||||
if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
|
||||
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data or None}
|
||||
if video_metadata_kwargs is not None:
|
||||
vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
|
||||
|
||||
vllm_inputs.append(vllm_input_data)
|
||||
|
||||
@@ -180,35 +180,32 @@ class VllmEngine(BaseEngine):
|
||||
else self.generating_args["skip_special_tokens"],
|
||||
)
|
||||
|
||||
multi_modal_data = {}
|
||||
if images is not None: # add image features
|
||||
multi_modal_data = {
|
||||
"image": self.template.mm_plugin._regularize_images(
|
||||
images,
|
||||
image_max_pixels=self.model_args.image_max_pixels,
|
||||
image_min_pixels=self.model_args.image_min_pixels,
|
||||
)["images"]
|
||||
}
|
||||
elif videos is not None:
|
||||
multi_modal_data = {
|
||||
"video": self.template.mm_plugin._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=self.model_args.video_max_pixels,
|
||||
image_min_pixels=self.model_args.video_min_pixels,
|
||||
video_fps=self.model_args.video_fps,
|
||||
video_maxlen=self.model_args.video_maxlen,
|
||||
)["videos"]
|
||||
}
|
||||
elif audios is not None:
|
||||
multi_modal_data["image"] = self.template.mm_plugin._regularize_images(
|
||||
images,
|
||||
image_max_pixels=self.model_args.image_max_pixels,
|
||||
image_min_pixels=self.model_args.image_min_pixels,
|
||||
)["images"]
|
||||
|
||||
if videos is not None:
|
||||
multi_modal_data["video"] = self.template.mm_plugin._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=self.model_args.video_max_pixels,
|
||||
image_min_pixels=self.model_args.video_min_pixels,
|
||||
video_fps=self.model_args.video_fps,
|
||||
video_maxlen=self.model_args.video_maxlen,
|
||||
)["videos"]
|
||||
|
||||
if audios is not None:
|
||||
audio_data = self.template.mm_plugin._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=self.model_args.audio_sampling_rate,
|
||||
)
|
||||
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
|
||||
else:
|
||||
multi_modal_data = None
|
||||
multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
|
||||
|
||||
result_generator = self.model.generate(
|
||||
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
||||
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data or None},
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
lora_request=self.lora_request,
|
||||
|
||||
@@ -395,6 +395,24 @@ _register_composite_model(
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen3_5",
|
||||
projector_key="visual.merger",
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen3_5_moe",
|
||||
projector_key="visual.merger",
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="video_llava",
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ from omegaconf import OmegaConf
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from ..utils.env import is_env_enabled
|
||||
from ..utils.helper import set_seed
|
||||
from .data_args import DataArguments
|
||||
from .model_args import ModelArguments
|
||||
from .sample_args import SampleArguments
|
||||
@@ -56,6 +57,14 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
# Seed as early as possible after argument parsing so all downstream
|
||||
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
|
||||
for arg in parsed_args:
|
||||
seed = getattr(arg, "seed", None)
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
break
|
||||
|
||||
return tuple(parsed_args)
|
||||
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class TrainingArguments:
|
||||
metadata={"help": "Number of workers for batching."},
|
||||
)
|
||||
enable_activation_checkpointing: bool = field(
|
||||
default=True,
|
||||
default=False,
|
||||
metadata={"help": "Enable activation checkpointing for training."},
|
||||
)
|
||||
dist_config: PluginConfig | None = field(
|
||||
@@ -81,6 +81,10 @@ class TrainingArguments:
|
||||
default=None,
|
||||
metadata={"help": "Learning rate scheduler configuration for training."},
|
||||
)
|
||||
seed: int = field(
|
||||
default=42,
|
||||
metadata={"help": "Random seed that will be set at the beginning of training."},
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.dist_config = get_plugin_config(self.dist_config)
|
||||
|
||||
@@ -76,7 +76,7 @@ class BaseTrainer:
|
||||
if self.args.enable_activation_checkpointing:
|
||||
self.model.gradient_checkpointing_enable({"use_reentrant": False})
|
||||
|
||||
self._accelerate_engine = None
|
||||
self._deepspeed_engine = None
|
||||
dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
|
||||
|
||||
if dist_name == "deepspeed":
|
||||
@@ -108,6 +108,7 @@ class BaseTrainer:
|
||||
cutoff_len=self.args.cutoff_len,
|
||||
batching_workers=self.args.batching_workers,
|
||||
batching_strategy=self.args.batching_strategy,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
|
||||
def _shard_model(self) -> None:
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.utils.data import default_collate
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||
@@ -71,6 +72,7 @@ class BatchGenerator(Iterator):
|
||||
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
|
||||
pin_memory: bool = True,
|
||||
drop_last: bool = True,
|
||||
seed: int = 42,
|
||||
) -> None:
|
||||
self.dataset = dataset
|
||||
self.renderer = renderer
|
||||
@@ -82,6 +84,7 @@ class BatchGenerator(Iterator):
|
||||
self.batching_strategy = batching_strategy
|
||||
self.pin_memory = pin_memory
|
||||
self.drop_last = drop_last
|
||||
self.seed = seed
|
||||
# TODO: support length and infinity
|
||||
dp_size = DistributedInterface().get_world_size(Dim.DP)
|
||||
|
||||
@@ -128,12 +131,15 @@ class BatchGenerator(Iterator):
|
||||
num_replicas=DistributedInterface().get_world_size(Dim.DP),
|
||||
rank=DistributedInterface().get_rank(Dim.DP),
|
||||
shuffle=True,
|
||||
seed=0,
|
||||
seed=self.seed,
|
||||
drop_last=self.drop_last,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||
|
||||
generato_seed = torch.Generator()
|
||||
generato_seed.manual_seed(self.seed)
|
||||
|
||||
self._data_provider = StatefulDataLoader(
|
||||
self.dataset,
|
||||
batch_size=self.micro_batch_size * self.num_micro_batch,
|
||||
@@ -143,6 +149,7 @@ class BatchGenerator(Iterator):
|
||||
pin_memory=self.pin_memory,
|
||||
pin_memory_device=DistributedInterface().current_device.type,
|
||||
drop_last=self.drop_last,
|
||||
generator=generato_seed,
|
||||
)
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
self._length = len(self._data_provider)
|
||||
|
||||
@@ -166,12 +166,11 @@ class FSDP2Engine:
|
||||
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
|
||||
)
|
||||
|
||||
use_gradient_checkpointing = True # Could be configurable
|
||||
if use_gradient_checkpointing:
|
||||
# BaseTrainer is the single source of truth for gradient checkpointing.
|
||||
# FSDP2 only applies the input-grad compatibility hook when checkpointing is already enabled.
|
||||
if getattr(model, "is_gradient_checkpointing", False):
|
||||
if self.rank == 0:
|
||||
logger.info("Enabling gradient checkpointing (transformers native)...")
|
||||
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
logger.info("Gradient checkpointing is enabled. Applying FSDP2 input grad preparation.")
|
||||
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
|
||||
@@ -15,12 +15,22 @@
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers import set_seed as hf_set_seed
|
||||
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from .constants import IGNORE_INDEX
|
||||
from .types import BatchInput, ModelInput, Processor, Tensor
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
"""Set seed for reproducibility.
|
||||
|
||||
Args:
|
||||
seed: Random seed.
|
||||
"""
|
||||
hf_set_seed(seed)
|
||||
|
||||
|
||||
def is_tokenizer(processor: Processor) -> bool:
|
||||
"""Check if processor is tokenizer.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user