[data] fix mm plugin for qwen omni video training (#9388)
Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
@@ -68,6 +68,8 @@ if TYPE_CHECKING:
|
|||||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
|
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||||
from transformers.image_processing_utils import BaseImageProcessor
|
from transformers.image_processing_utils import BaseImageProcessor
|
||||||
|
from transformers.video_processing_utils import BaseVideoProcessor
|
||||||
|
|
||||||
|
|
||||||
class EncodedImage(TypedDict):
|
class EncodedImage(TypedDict):
|
||||||
path: Optional[str]
|
path: Optional[str]
|
||||||
@@ -1482,6 +1484,7 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
processor: "MMProcessor",
|
processor: "MMProcessor",
|
||||||
) -> dict[str, "torch.Tensor"]:
|
) -> dict[str, "torch.Tensor"]:
|
||||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||||
|
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
|
||||||
mm_inputs = {}
|
mm_inputs = {}
|
||||||
if len(images) != 0:
|
if len(images) != 0:
|
||||||
images = self._regularize_images(
|
images = self._regularize_images(
|
||||||
@@ -1499,7 +1502,7 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
video_fps=getattr(processor, "video_fps", 2.0),
|
video_fps=getattr(processor, "video_fps", 2.0),
|
||||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||||
)
|
)
|
||||||
mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
|
mm_inputs.update(video_processor(videos=video_data["videos"], return_tensors="pt"))
|
||||||
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
|
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
|
||||||
if "second_per_grid_ts" in processor.model_input_names:
|
if "second_per_grid_ts" in processor.model_input_names:
|
||||||
mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
|
mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]
|
||||||
@@ -1818,6 +1821,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
processor: "MMProcessor",
|
processor: "MMProcessor",
|
||||||
) -> dict[str, "torch.Tensor"]:
|
) -> dict[str, "torch.Tensor"]:
|
||||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||||
|
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
|
||||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||||
mm_inputs = {}
|
mm_inputs = {}
|
||||||
if len(images) != 0:
|
if len(images) != 0:
|
||||||
@@ -1836,7 +1840,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
video_fps=getattr(processor, "video_fps", 2.0),
|
video_fps=getattr(processor, "video_fps", 2.0),
|
||||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||||
)
|
)
|
||||||
mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
|
mm_inputs.update(video_processor(videos=video_dict["videos"], return_tensors="pt"))
|
||||||
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
|
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
|
||||||
mm_inputs["video_second_per_grid"] = torch.tensor(
|
mm_inputs["video_second_per_grid"] = torch.tensor(
|
||||||
[temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
|
[temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ def launch():
|
|||||||
if is_env_enabled("USE_MCA"):
|
if is_env_enabled("USE_MCA"):
|
||||||
# force use torchrun
|
# force use torchrun
|
||||||
os.environ["FORCE_TORCHRUN"] = "1"
|
os.environ["FORCE_TORCHRUN"] = "1"
|
||||||
|
|
||||||
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
|
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
|
||||||
# launch distributed training
|
# launch distributed training
|
||||||
nnodes = os.getenv("NNODES", "1")
|
nnodes = os.getenv("NNODES", "1")
|
||||||
|
|||||||
Reference in New Issue
Block a user