From db2f794f7b110e6171a533639220551edcf43a28 Mon Sep 17 00:00:00 2001 From: Kingsley Date: Mon, 19 Jan 2026 14:55:16 +0800 Subject: [PATCH] [misc] update mcore related docker and mca supported models (#10114) --- docker/docker-cuda/Dockerfile.megatron | 31 +++++++++++++------------- src/llamafactory/extras/constants.py | 1 + src/llamafactory/hparams/parser.py | 4 ++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docker/docker-cuda/Dockerfile.megatron b/docker/docker-cuda/Dockerfile.megatron index 8b44434d4..6e6e1cb29 100644 --- a/docker/docker-cuda/Dockerfile.megatron +++ b/docker/docker-cuda/Dockerfile.megatron @@ -1,12 +1,13 @@ -# NVIDIA official image (ubuntu-22.04 + cuda-12.4 + python-3.10) -# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html -FROM nvcr.io/nvidia/pytorch:24.05-py3 +# NVIDIA official image (ubuntu-24.04 + cuda-12.9.1 + python-3.12) +# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-25-06.html +FROM nvcr.io/nvidia/pytorch:25.06-py3 ENV DEBIAN_FRONTEND=noninteractive ENV PIP_ROOT_USER_ACTION=ignore ENV PYPI_MIRROR=https://mirrors.aliyun.com/pypi/simple/ ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ +ENV PIP_CONSTRAINT="" RUN pip install --upgrade pip setuptools wheel "hatchling>=1.18.0" editables --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR} @@ -14,20 +15,14 @@ RUN pip uninstall -y torch torchvision torch-tensorrt \ flash_attn transformer-engine \ cudf dask-cuda cugraph cugraph-service-server cuml raft-dask cugraph-dgl cugraph-pyg dask-cudf -RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 +RUN pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu129 RUN pip uninstall -y opencv opencv-python opencv-python-headless && \ - rm -rf /usr/local/lib/python3.10/dist-packages/cv2/ && \ + rm -rf /usr/local/lib/python3.12/dist-packages/cv2/ && \ pip install opencv-python-headless==4.11.0.86 --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR} -RUN pip install "numpy==1.26.4" "optree>=0.13.0" "spacy==3.7.5" "weasel==0.4.1" \ - transformer-engine[pytorch]==2.2.0 megatron-core==0.13.0 deepspeed==0.16.4 \ - --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR} - -RUN pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl - -# RUN pip install vllm==0.8.4 \ -# --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR} +RUN pip install --trusted-host mirrors.aliyun.com --index-url ${PYPI_MIRROR} \ + "megatron-core>=0.13.0,<0.14.0" "deepspeed==0.16.4" WORKDIR /build @@ -37,6 +32,8 @@ RUN pip uninstall -y apex && \ pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \ --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 32" ${apex_url} +RUN pip install --no-build-isolation transformer_engine[pytorch] + RUN rm -rf /build WORKDIR /workspace @@ -53,11 +50,13 @@ RUN apt-get update && apt-get install -y zip RUN apt-get install -y openjdk-21-jdk ENV JAVA_HOME /usr/lib/jvm/java-21-openjdk-amd64 -# pip install LLaMA-Factory +ARG REPO_URL=https://github.com/hiyouga/LlamaFactory.git +ARG BRANCH=main WORKDIR /app -# Copy the application into the image -COPY . /app +# Clone the repository +RUN git clone --depth 1 --branch ${BRANCH} ${REPO_URL} /app || \ + git clone --depth 1 ${REPO_URL} /app # Install LLaMA Factory RUN pip install --no-cache-dir -e . --no-build-isolation && \ diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index fdfaa46f2..6baf51435 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -57,6 +57,7 @@ LLAMABOARD_CONFIG = "llamaboard_config.yaml" MCA_SUPPORTED_MODELS = { "deepseek_v3", + "glm4_moe", "llama", "mistral", "mixtral", diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index ae8b3e424..5cb438919 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -340,7 +340,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo): raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.") - if training_args.fp8 and model_args.quantization_bit is not None: + if not finetuning_args.use_mca and training_args.fp8 and model_args.quantization_bit is not None: raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.") if model_args.infer_backend != EngineName.HF: @@ -359,7 +359,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS _verify_model_args(model_args, data_args, finetuning_args) _check_extra_dependencies(model_args, finetuning_args, training_args) - if training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8: + if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8: logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.") model_args.fp8 = True