mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-21 12:03:08 +00:00
Compare commits
37 Commits
1d5e8ebcd0
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
833f6027b1 | ||
|
|
d91d8af89e | ||
|
|
e67ab9e2f2 | ||
|
|
2c4f121817 | ||
|
|
487f8b8191 | ||
|
|
78cad1e332 | ||
|
|
70653026f5 | ||
|
|
246192abd2 | ||
|
|
0258dc14d0 | ||
|
|
3045adf0ba | ||
|
|
a3d44e3152 | ||
|
|
edeb953bc7 | ||
|
|
d045794387 | ||
|
|
9501c3308a | ||
|
|
0ee1c42c2b | ||
|
|
3061f48d55 | ||
|
|
2d9bd2aa14 | ||
|
|
c0245c43fc | ||
|
|
eb976d75a2 | ||
|
|
b5cb7cb0e6 | ||
|
|
0779846513 | ||
|
|
45d335c709 | ||
|
|
816480012f | ||
|
|
d3bf882e87 | ||
|
|
589da21d32 | ||
|
|
122cd46084 | ||
|
|
2b8b871475 | ||
|
|
aab9b400bb | ||
|
|
50599c719b | ||
|
|
a0f3ad0cee | ||
|
|
f80e15dbb4 | ||
|
|
991267fd3b | ||
|
|
5c52afa30d | ||
|
|
675ce8cc7f | ||
|
|
ab073f4c13 | ||
|
|
184304b5b4 | ||
|
|
d3ebd5678d |
10
.github/workflows/docs.yml
vendored
10
.github/workflows/docs.yml
vendored
@@ -25,16 +25,16 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip install -r docs/requirements.txt
|
pip install -r docs/requirements.txt
|
||||||
|
|
||||||
- name: Build Sphinx
|
- name: Build Sphinx
|
||||||
run: |
|
run: |
|
||||||
sphinx-build -b html docs/zh docs/_build/html/zh
|
sphinx-build -b html docs/zh docs/_build/html/zh
|
||||||
@@ -56,10 +56,10 @@ jobs:
|
|||||||
> docs/_build/html/index.html
|
> docs/_build/html/index.html
|
||||||
|
|
||||||
touch docs/_build/html/.nojekyll
|
touch docs/_build/html/.nojekyll
|
||||||
|
|
||||||
- name: Setup Pages
|
- name: Setup Pages
|
||||||
uses: actions/configure-pages@v5
|
uses: actions/configure-pages@v5
|
||||||
|
|
||||||
- name: Upload artifact
|
- name: Upload artifact
|
||||||
uses: actions/upload-pages-artifact@v3
|
uses: actions/upload-pages-artifact@v3
|
||||||
with:
|
with:
|
||||||
|
|||||||
9
.github/workflows/tests.yml
vendored
9
.github/workflows/tests.yml
vendored
@@ -35,15 +35,12 @@ jobs:
|
|||||||
transformers:
|
transformers:
|
||||||
- ""
|
- ""
|
||||||
include: # test backward compatibility
|
include: # test backward compatibility
|
||||||
- python: "3.11"
|
|
||||||
os: "ubuntu-latest"
|
|
||||||
transformers: "4.51.0"
|
|
||||||
- python: "3.11"
|
|
||||||
os: "ubuntu-latest"
|
|
||||||
transformers: "4.53.0"
|
|
||||||
- python: "3.11"
|
- python: "3.11"
|
||||||
os: "ubuntu-latest"
|
os: "ubuntu-latest"
|
||||||
transformers: "4.55.0"
|
transformers: "4.55.0"
|
||||||
|
- python: "3.11"
|
||||||
|
os: "ubuntu-latest"
|
||||||
|
transformers: "4.57.1"
|
||||||
|
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
|
|||||||
1
.github/workflows/tests_cuda.yml
vendored
1
.github/workflows/tests_cuda.yml
vendored
@@ -61,6 +61,7 @@ jobs:
|
|||||||
uv venv
|
uv venv
|
||||||
uv pip install -e .
|
uv pip install -e .
|
||||||
uv pip install -r requirements/dev.txt
|
uv pip install -r requirements/dev.txt
|
||||||
|
uv pip install -r requirements/bitsandbytes.txt
|
||||||
|
|
||||||
- name: Check quality
|
- name: Check quality
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -291,7 +291,7 @@ Read technical notes:
|
|||||||
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
|
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
|
||||||
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
|
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
|
||||||
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
|
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
|
||||||
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
|
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small|
|
||||||
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
|
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
|
||||||
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
|
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
|
||||||
| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
|
| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
|
||||||
@@ -319,6 +319,7 @@ Read technical notes:
|
|||||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||||
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||||
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
||||||
|
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
|
||||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
||||||
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
|
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
|
||||||
@@ -472,7 +473,7 @@ huggingface-cli login
|
|||||||
|
|
||||||
| Mandatory | Minimum | Recommend |
|
| Mandatory | Minimum | Recommend |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
| python | 3.9 | 3.10 |
|
| python | 3.11 | >=3.11 |
|
||||||
| torch | 2.0.0 | 2.6.0 |
|
| torch | 2.0.0 | 2.6.0 |
|
||||||
| torchvision | 0.15.0 | 0.21.0 |
|
| torchvision | 0.15.0 | 0.21.0 |
|
||||||
| transformers | 4.49.0 | 4.50.0 |
|
| transformers | 4.49.0 | 4.50.0 |
|
||||||
|
|||||||
@@ -293,7 +293,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
|||||||
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
|
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
|
||||||
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
|
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
|
||||||
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
|
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
|
||||||
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
|
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small|
|
||||||
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
|
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
|
||||||
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
|
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
|
||||||
| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
|
| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
|
||||||
@@ -321,6 +321,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
|||||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||||
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||||
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
||||||
|
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
|
||||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
||||||
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
|
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
|
||||||
@@ -474,7 +475,7 @@ huggingface-cli login
|
|||||||
|
|
||||||
| 必需项 | 至少 | 推荐 |
|
| 必需项 | 至少 | 推荐 |
|
||||||
| ------------ | ------- | --------- |
|
| ------------ | ------- | --------- |
|
||||||
| python | 3.9 | 3.10 |
|
| python | 3.11 | >=3.11 |
|
||||||
| torch | 2.0.0 | 2.6.0 |
|
| torch | 2.0.0 | 2.6.0 |
|
||||||
| torchvision | 0.15.0 | 0.21.0 |
|
| torchvision | 0.15.0 | 0.21.0 |
|
||||||
| transformers | 4.49.0 | 4.50.0 |
|
| transformers | 4.49.0 | 4.50.0 |
|
||||||
|
|||||||
@@ -236,6 +236,13 @@
|
|||||||
"ms_hub_url": "AI-ModelScope/sharegpt_gpt4",
|
"ms_hub_url": "AI-ModelScope/sharegpt_gpt4",
|
||||||
"formatting": "sharegpt"
|
"formatting": "sharegpt"
|
||||||
},
|
},
|
||||||
|
"sgsc_b2b_entities": {
|
||||||
|
"hf_hub_url": "Nooxus-AI/NOO-Verified-Global-Entities",
|
||||||
|
"formatting": "sharegpt",
|
||||||
|
"columns": {
|
||||||
|
"messages": "messages"
|
||||||
|
}
|
||||||
|
},
|
||||||
"ultrachat_200k": {
|
"ultrachat_200k": {
|
||||||
"hf_hub_url": "HuggingFaceH4/ultrachat_200k",
|
"hf_hub_url": "HuggingFaceH4/ultrachat_200k",
|
||||||
"ms_hub_url": "AI-ModelScope/ultrachat_200k",
|
"ms_hub_url": "AI-ModelScope/ultrachat_200k",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# https://hub.docker.com/r/ascendai/cann/tags
|
# https://hub.docker.com/r/ascendai/cann/tags
|
||||||
|
|
||||||
ARG BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11
|
ARG BASE_IMAGE=quay.io/ascend/cann:8.5.1-910b-ubuntu22.04-py3.11
|
||||||
FROM ${BASE_IMAGE}
|
FROM ${BASE_IMAGE}
|
||||||
|
|
||||||
# Installation arguments
|
# Installation arguments
|
||||||
@@ -33,9 +33,11 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
|
|||||||
COPY . /app
|
COPY . /app
|
||||||
|
|
||||||
# Install torch-npu
|
# Install torch-npu
|
||||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||||
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
|
RUN pip uninstall -y torch torchvision torchaudio
|
||||||
pip install --no-cache-dir -e . --no-build-isolation && \
|
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
|
||||||
|
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
|
||||||
|
RUN pip install --no-cache-dir -e . --no-build-isolation && \
|
||||||
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||||
|
|
||||||
# Set up volumes
|
# Set up volumes
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ services:
|
|||||||
dockerfile: ./docker/docker-npu/Dockerfile
|
dockerfile: ./docker/docker-npu/Dockerfile
|
||||||
context: ../..
|
context: ../..
|
||||||
args:
|
args:
|
||||||
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
|
BASE_IMAGE: quay.io/ascend/cann:8.5.1-a3-ubuntu22.04-py3.11
|
||||||
PIP_INDEX: https://pypi.org/simple
|
PIP_INDEX: https://pypi.org/simple
|
||||||
container_name: llamafactory-a3
|
container_name: llamafactory-a3
|
||||||
image: llamafactory:npu-a3
|
image: llamafactory:npu-a3
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
# https://hub.docker.com/r/rocm/pytorch/tags
|
# https://hub.docker.com/r/rocm/pytorch/tags
|
||||||
ARG BASE_IMAGE=rocm/pytorch:rocm6.4.1_ubuntu22.04_py3.10_pytorch_release_2.6.0
|
# ROCm 7.2 + PyTorch 2.7.1 (Python 3.12). Keep base image's PyTorch; do not reinstall.
|
||||||
|
ARG BASE_IMAGE=rocm/pytorch:rocm7.2_ubuntu24.04_py3.12_pytorch_release_2.7.1
|
||||||
FROM ${BASE_IMAGE}
|
FROM ${BASE_IMAGE}
|
||||||
|
|
||||||
# Installation arguments
|
# Installation arguments
|
||||||
ARG PIP_INDEX=https://pypi.org/simple
|
ARG PIP_INDEX=https://pypi.org/simple
|
||||||
ARG INSTALL_FLASHATTN=false
|
ARG INSTALL_FLASHATTN=false
|
||||||
ARG HTTP_PROXY=""
|
ARG HTTP_PROXY=""
|
||||||
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3
|
|
||||||
|
|
||||||
# Define environments
|
# Define environments
|
||||||
ENV MAX_JOBS=16
|
ENV MAX_JOBS=16
|
||||||
@@ -32,10 +32,9 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
|
|||||||
# Copy the application into the image
|
# Copy the application into the image
|
||||||
COPY . /app
|
COPY . /app
|
||||||
|
|
||||||
# Reinstall pytorch rocm and install LLaMA Factory
|
# Install LLaMA Factory (use base image's PyTorch/ROCm; do not reinstall)
|
||||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
RUN pip install --no-cache-dir -e . --pre && \
|
||||||
pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
|
pip install --no-cache-dir -r requirements/deepspeed.txt -r requirements/liger-kernel.txt -r requirements/bitsandbytes.txt
|
||||||
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"
|
|
||||||
|
|
||||||
# Rebuild flash attention
|
# Rebuild flash attention
|
||||||
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
||||||
|
|||||||
1
docs/_static/css/lang-switcher.css
vendored
1
docs/_static/css/lang-switcher.css
vendored
@@ -47,4 +47,3 @@
|
|||||||
border-color: rgba(255, 255, 255, 0.45);
|
border-color: rgba(255, 255, 255, 0.45);
|
||||||
box-shadow: 0 0 0 3px rgba(255, 255, 255, 0.12);
|
box-shadow: 0 0 0 3px rgba(255, 255, 255, 0.12);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
28
docs/conf.py
28
docs/conf.py
@@ -1,33 +1,31 @@
|
|||||||
# Configuration file for the Sphinx documentation builder.
|
# Configuration file for the Sphinx documentation builder.
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# Define common settings here
|
# Define common settings here
|
||||||
project = 'LlamaFactory'
|
project = "LlamaFactory"
|
||||||
copyright = '2024, LlamaFactory Team'
|
copyright = "2024, LlamaFactory Team"
|
||||||
author = 'LlamaFactory Team'
|
author = "LlamaFactory Team"
|
||||||
|
|
||||||
extensions = [
|
extensions = [
|
||||||
'sphinx.ext.autodoc',
|
"sphinx.ext.autodoc",
|
||||||
'sphinx.ext.viewcode',
|
"sphinx.ext.viewcode",
|
||||||
'sphinx.ext.napoleon',
|
"sphinx.ext.napoleon",
|
||||||
'myst_parser',
|
"myst_parser",
|
||||||
]
|
]
|
||||||
|
|
||||||
templates_path = ['_templates']
|
templates_path = ["_templates"]
|
||||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||||
|
|
||||||
html_theme = 'sphinx_rtd_theme'
|
html_theme = "sphinx_rtd_theme"
|
||||||
|
|
||||||
html_static_path = ['_static']
|
html_static_path = ["_static"]
|
||||||
|
|
||||||
html_js_files = [
|
html_js_files = [
|
||||||
'js/switcher.js',
|
"js/switcher.js",
|
||||||
]
|
]
|
||||||
|
|
||||||
html_css_files = [
|
html_css_files = [
|
||||||
'css/lang-switcher.css',
|
"css/lang-switcher.css",
|
||||||
]
|
]
|
||||||
|
|
||||||
myst_enable_extensions = [
|
myst_enable_extensions = [
|
||||||
|
|||||||
@@ -1,20 +1,22 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# Add parent dir to path to allow importing conf.py
|
|
||||||
sys.path.insert(0, os.path.abspath('..'))
|
|
||||||
|
|
||||||
from conf import *
|
# Add parent dir to path to allow importing conf.py
|
||||||
|
sys.path.insert(0, os.path.abspath(".."))
|
||||||
|
|
||||||
|
from conf import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
# Language settings
|
# Language settings
|
||||||
language = 'en'
|
language = "en"
|
||||||
html_search_language = 'en'
|
html_search_language = "en"
|
||||||
|
|
||||||
# Static files
|
# Static files
|
||||||
# Point to the root _static directory
|
# Point to the root _static directory
|
||||||
html_static_path = ['../_static']
|
html_static_path = ["../_static"]
|
||||||
|
|
||||||
# Add custom JS for language switcher
|
# Add custom JS for language switcher
|
||||||
html_js_files = [
|
html_js_files = [
|
||||||
'js/switcher.js',
|
"js/switcher.js",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,20 +1,22 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# Add parent dir to path to allow importing conf.py
|
|
||||||
sys.path.insert(0, os.path.abspath('..'))
|
|
||||||
|
|
||||||
from conf import *
|
# Add parent dir to path to allow importing conf.py
|
||||||
|
sys.path.insert(0, os.path.abspath(".."))
|
||||||
|
|
||||||
|
from conf import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
# Language settings
|
# Language settings
|
||||||
language = 'zh_CN'
|
language = "zh_CN"
|
||||||
html_search_language = 'zh'
|
html_search_language = "zh"
|
||||||
|
|
||||||
# Static files
|
# Static files
|
||||||
# Point to the root _static directory
|
# Point to the root _static directory
|
||||||
html_static_path = ['../_static']
|
html_static_path = ["../_static"]
|
||||||
|
|
||||||
# Add custom JS for language switcher
|
# Add custom JS for language switcher
|
||||||
html_js_files = [
|
html_js_files = [
|
||||||
'js/switcher.js',
|
"js/switcher.js",
|
||||||
]
|
]
|
||||||
|
|||||||
45
examples/extras/asft/llama2_full_asft.yaml
Normal file
45
examples/extras/asft/llama2_full_asft.yaml
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
### model
|
||||||
|
model_name_or_path: models/Llama-2-7b
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: full
|
||||||
|
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||||
|
use_asft_loss: true
|
||||||
|
asft_alpha: 0.1
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: med
|
||||||
|
template: llama2
|
||||||
|
cutoff_len: 2048
|
||||||
|
max_samples: 10000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: saves/llama2-7b/full/asft2
|
||||||
|
logging_steps: 1
|
||||||
|
save_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 4
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
learning_rate: 2.0e-5
|
||||||
|
num_train_epochs: 3.0
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
|
### eval
|
||||||
|
# val_size: 0.1
|
||||||
|
# per_device_eval_batch_size: 1
|
||||||
|
# eval_strategy: steps
|
||||||
|
# eval_steps: 500
|
||||||
45
examples/extras/asft/qwen2_full_asft.yaml
Normal file
45
examples/extras/asft/qwen2_full_asft.yaml
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
### model
|
||||||
|
model_name_or_path: models/Qwen2.5-7B
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
### method
|
||||||
|
stage: sft
|
||||||
|
do_train: true
|
||||||
|
finetuning_type: full
|
||||||
|
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||||
|
use_asft_loss: true
|
||||||
|
asft_alpha: 0.05
|
||||||
|
|
||||||
|
### dataset
|
||||||
|
dataset: math
|
||||||
|
template: qwen
|
||||||
|
cutoff_len: 2048
|
||||||
|
max_samples: 10000
|
||||||
|
overwrite_cache: true
|
||||||
|
preprocessing_num_workers: 16
|
||||||
|
dataloader_num_workers: 4
|
||||||
|
|
||||||
|
### output
|
||||||
|
output_dir: saves/qwen2-7b/full/asft
|
||||||
|
logging_steps: 10
|
||||||
|
save_steps: 500
|
||||||
|
plot_loss: true
|
||||||
|
overwrite_output_dir: true
|
||||||
|
save_only_model: false
|
||||||
|
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||||
|
|
||||||
|
### train
|
||||||
|
per_device_train_batch_size: 4
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
learning_rate: 5.0e-5
|
||||||
|
num_train_epochs: 1.0
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
bf16: true
|
||||||
|
ddp_timeout: 180000000
|
||||||
|
|
||||||
|
### eval
|
||||||
|
# val_size: 0.1
|
||||||
|
# per_device_eval_batch_size: 1
|
||||||
|
# eval_strategy: steps
|
||||||
|
# eval_steps: 500
|
||||||
@@ -28,12 +28,7 @@ save_only_model: false
|
|||||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||||
|
|
||||||
### ray
|
### ray
|
||||||
ray_run_name: qwen3_4b_sft_lora
|
|
||||||
ray_storage_path: ./saves
|
|
||||||
ray_num_workers: 4 # Number of GPUs to use.
|
ray_num_workers: 4 # Number of GPUs to use.
|
||||||
placement_strategy: PACK
|
|
||||||
resources_per_worker:
|
|
||||||
GPU: 1
|
|
||||||
# ray_init_kwargs:
|
# ray_init_kwargs:
|
||||||
# runtime_env:
|
# runtime_env:
|
||||||
# env_vars:
|
# env_vars:
|
||||||
|
|||||||
38
examples/v1/train_freeze/train_freeze_sft.yaml
Normal file
38
examples/v1/train_freeze/train_freeze_sft.yaml
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
model: Qwen/Qwen3-4B
|
||||||
|
trust_remote_code: true
|
||||||
|
model_class: llm
|
||||||
|
|
||||||
|
template: qwen3_nothink
|
||||||
|
|
||||||
|
# Freeze Configuration
|
||||||
|
peft_config:
|
||||||
|
name: freeze
|
||||||
|
freeze_trainable_layers: 2 # Train the last 2 layers
|
||||||
|
freeze_trainable_modules: all # In these layers, train specific modules
|
||||||
|
freeze_extra_modules: null # Extra modules to train (e.g. embed_tokens, lm_head)
|
||||||
|
|
||||||
|
# Kernel Config
|
||||||
|
kernel_config:
|
||||||
|
name: auto
|
||||||
|
include_kernels: auto
|
||||||
|
|
||||||
|
# FSDP Config
|
||||||
|
dist_config:
|
||||||
|
name: fsdp2
|
||||||
|
dcp_path: null
|
||||||
|
|
||||||
|
### data
|
||||||
|
train_dataset: data/v1_sft_demo.yaml
|
||||||
|
|
||||||
|
### training
|
||||||
|
output_dir: ./outputs/test_freeze
|
||||||
|
micro_batch_size: 1
|
||||||
|
global_batch_size: 4
|
||||||
|
cutoff_len: 2048
|
||||||
|
learning_rate: 2.0e-5
|
||||||
|
bf16: false
|
||||||
|
max_steps: 10
|
||||||
|
|
||||||
|
### sample
|
||||||
|
sample_backend: hf
|
||||||
|
max_new_tokens: 128
|
||||||
24
examples/v1/train_full/train_full_deepspeed.yaml
Normal file
24
examples/v1/train_full/train_full_deepspeed.yaml
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
model: Qwen/Qwen3-0.6B
|
||||||
|
|
||||||
|
model_class: llm
|
||||||
|
|
||||||
|
template: qwen3_nothink
|
||||||
|
|
||||||
|
kernel_config:
|
||||||
|
name: auto
|
||||||
|
include_kernels: auto
|
||||||
|
|
||||||
|
dist_config:
|
||||||
|
name: deepspeed
|
||||||
|
config_file: examples/deepspeed/ds_z3_config.json
|
||||||
|
|
||||||
|
### data
|
||||||
|
train_dataset: data/v1_sft_demo.yaml
|
||||||
|
|
||||||
|
### training
|
||||||
|
output_dir: outputs/Qwen3-0.6B-deepspeed
|
||||||
|
micro_batch_size: 1
|
||||||
|
cutoff_len: 2048
|
||||||
|
learning_rate: 1.0e-4
|
||||||
|
bf16: true
|
||||||
|
max_steps: 10
|
||||||
@@ -14,16 +14,12 @@ dist_config:
|
|||||||
name: fsdp2
|
name: fsdp2
|
||||||
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
||||||
|
|
||||||
init_config:
|
|
||||||
name: init_on_meta
|
|
||||||
|
|
||||||
### data
|
### data
|
||||||
train_dataset: data/v1_sft_demo.yaml
|
train_dataset: data/v1_sft_demo.yaml
|
||||||
|
|
||||||
### training
|
### training
|
||||||
output_dir: outputs/test_fsdp2
|
output_dir: outputs/test_fsdp2
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
global_batch_size: 1
|
|
||||||
cutoff_len: 2048
|
cutoff_len: 2048
|
||||||
learning_rate: 1.0e-4
|
learning_rate: 1.0e-4
|
||||||
bf16: false
|
bf16: false
|
||||||
|
|||||||
7
examples/v1/train_lora/export_lora.yaml
Normal file
7
examples/v1/train_lora/export_lora.yaml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
model: Qwen/Qwen3-4B
|
||||||
|
peft_config:
|
||||||
|
name: lora
|
||||||
|
adapter_name_or_path: ./outputs/test_lora
|
||||||
|
export_dir: ./merge_lora_model
|
||||||
|
export_size: 5
|
||||||
|
infer_dtype: auto
|
||||||
39
examples/v1/train_lora/train_lora_sft.yaml
Normal file
39
examples/v1/train_lora/train_lora_sft.yaml
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
model: Qwen/Qwen3-4B
|
||||||
|
trust_remote_code: true
|
||||||
|
model_class: llm
|
||||||
|
|
||||||
|
template: qwen3_nothink
|
||||||
|
|
||||||
|
# PEFT Configuration
|
||||||
|
peft_config:
|
||||||
|
name: lora
|
||||||
|
r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.05
|
||||||
|
target_modules: all
|
||||||
|
|
||||||
|
# Kernel Config
|
||||||
|
kernel_config:
|
||||||
|
name: auto
|
||||||
|
include_kernels: auto
|
||||||
|
|
||||||
|
# FSDP Config
|
||||||
|
dist_config:
|
||||||
|
name: fsdp2
|
||||||
|
dcp_path: null
|
||||||
|
|
||||||
|
### data
|
||||||
|
train_dataset: data/v1_sft_demo.yaml
|
||||||
|
|
||||||
|
### training
|
||||||
|
output_dir: ./outputs/test_lora
|
||||||
|
micro_batch_size: 1
|
||||||
|
global_batch_size: 4
|
||||||
|
cutoff_len: 2048
|
||||||
|
learning_rate: 1.0e-4
|
||||||
|
bf16: true
|
||||||
|
max_steps: 10
|
||||||
|
|
||||||
|
### sample
|
||||||
|
sample_backend: hf
|
||||||
|
max_new_tokens: 128
|
||||||
43
examples/v1/train_qlora/quantization.yaml
Normal file
43
examples/v1/train_qlora/quantization.yaml
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
model: Qwen/Qwen3-0.6B
|
||||||
|
trust_remote_code: true
|
||||||
|
model_class: llm
|
||||||
|
|
||||||
|
template: qwen3_nothink
|
||||||
|
|
||||||
|
# PEFT Configuration
|
||||||
|
peft_config:
|
||||||
|
name: lora
|
||||||
|
r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.05
|
||||||
|
target_modules: all
|
||||||
|
|
||||||
|
# Kernel Config
|
||||||
|
kernel_config:
|
||||||
|
name: auto
|
||||||
|
include_kernels: auto
|
||||||
|
|
||||||
|
# FSDP Config
|
||||||
|
dist_config:
|
||||||
|
name: fsdp2
|
||||||
|
dcp_path: null
|
||||||
|
|
||||||
|
# Quantization Config
|
||||||
|
quant_config:
|
||||||
|
name: bnb # choice: auto/bnb if auto is selected, the quantization method will be automatically selected based on the model and environment.
|
||||||
|
quantization_bit: 4 # choice: 8/4(bnb)
|
||||||
|
|
||||||
|
### data
|
||||||
|
train_dataset: data/v1_sft_demo.yaml
|
||||||
|
|
||||||
|
### training
|
||||||
|
output_dir: outputs/test_quantization
|
||||||
|
micro_batch_size: 1
|
||||||
|
cutoff_len: 2048
|
||||||
|
learning_rate: 1.0e-4
|
||||||
|
bf16: false
|
||||||
|
max_steps: 10
|
||||||
|
|
||||||
|
### sample
|
||||||
|
sample_backend: hf
|
||||||
|
max_new_tokens: 128
|
||||||
@@ -40,7 +40,7 @@ dependencies = [
|
|||||||
"torch>=2.4.0",
|
"torch>=2.4.0",
|
||||||
"torchvision>=0.19.0",
|
"torchvision>=0.19.0",
|
||||||
"torchaudio>=2.4.0",
|
"torchaudio>=2.4.0",
|
||||||
"transformers>=4.51.0,<=5.0.0,!=4.52.0,!=4.57.0",
|
"transformers>=4.55.0,<=5.2.0,!=4.52.0,!=4.57.0",
|
||||||
"datasets>=2.16.0,<=4.0.0",
|
"datasets>=2.16.0,<=4.0.0",
|
||||||
"accelerate>=1.3.0,<=1.11.0",
|
"accelerate>=1.3.0,<=1.11.0",
|
||||||
"peft>=0.18.0,<=0.18.1",
|
"peft>=0.18.0,<=0.18.1",
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
liger-kernel>=0.5.5
|
liger-kernel>=0.6.3
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
torch==2.7.1
|
torch==2.7.1
|
||||||
torch-npu==2.7.1
|
torch-npu==2.7.1.post2
|
||||||
torchvision==0.22.1
|
torchvision==0.22.1
|
||||||
torchaudio==2.7.1
|
torchaudio==2.7.1
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ def convert(
|
|||||||
pipeline_model_parallel_size: int = 1,
|
pipeline_model_parallel_size: int = 1,
|
||||||
expert_model_parallel_size: int = 1,
|
expert_model_parallel_size: int = 1,
|
||||||
virtual_pipeline_model_parallel_size: int | None = None,
|
virtual_pipeline_model_parallel_size: int | None = None,
|
||||||
|
moe_grouped_gemm: bool | None = None,
|
||||||
):
|
):
|
||||||
"""Convert checkpoint between MCA and HuggingFace formats.
|
"""Convert checkpoint between MCA and HuggingFace formats.
|
||||||
|
|
||||||
@@ -84,6 +85,10 @@ def convert(
|
|||||||
pipeline_model_parallel_size: Pipeline model parallel size
|
pipeline_model_parallel_size: Pipeline model parallel size
|
||||||
expert_model_parallel_size: Expert model parallel size
|
expert_model_parallel_size: Expert model parallel size
|
||||||
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
|
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
|
||||||
|
moe_grouped_gemm: Use grouped gemm for MoE experts. When enabled, expert
|
||||||
|
weights are stored in a flattened format (linear_fc1.weight0, weight1, ...)
|
||||||
|
rather than per-expert format (local_experts.0.linear_fc1.weight, ...).
|
||||||
|
Must match the format used when saving the checkpoint.
|
||||||
"""
|
"""
|
||||||
if bf16 and fp16:
|
if bf16 and fp16:
|
||||||
raise ValueError("bf16 and fp16 cannot be both True.")
|
raise ValueError("bf16 and fp16 cannot be both True.")
|
||||||
@@ -97,8 +102,9 @@ def convert(
|
|||||||
pipeline_model_parallel_size=pipeline_model_parallel_size,
|
pipeline_model_parallel_size=pipeline_model_parallel_size,
|
||||||
expert_model_parallel_size=expert_model_parallel_size,
|
expert_model_parallel_size=expert_model_parallel_size,
|
||||||
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
|
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
|
||||||
|
moe_grouped_gemm=moe_grouped_gemm,
|
||||||
|
transformer_impl="transformer_engine", # hard code here since we default using te for training
|
||||||
)
|
)
|
||||||
|
|
||||||
convert_checkpoint_to_mca(
|
convert_checkpoint_to_mca(
|
||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
output_path,
|
output_path,
|
||||||
|
|||||||
@@ -154,25 +154,24 @@ def vllm_infer(
|
|||||||
batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
|
batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
|
||||||
|
|
||||||
for j in range(len(batch["input_ids"])):
|
for j in range(len(batch["input_ids"])):
|
||||||
|
multi_modal_data = {}
|
||||||
|
video_metadata_kwargs = None
|
||||||
|
|
||||||
if batch["images"][j] is not None:
|
if batch["images"][j] is not None:
|
||||||
image = batch["images"][j]
|
image = batch["images"][j]
|
||||||
multi_modal_data = {
|
multi_modal_data["image"] = template_obj.mm_plugin._regularize_images(
|
||||||
"image": template_obj.mm_plugin._regularize_images(
|
image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
|
||||||
image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
|
)["images"]
|
||||||
)["images"]
|
|
||||||
}
|
if batch["videos"][j] is not None:
|
||||||
elif batch["videos"][j] is not None:
|
|
||||||
video_metadata, video_metadata_kwargs = None, None
|
|
||||||
video = batch["videos"][j]
|
video = batch["videos"][j]
|
||||||
multi_modal_data = {
|
multi_modal_data["video"] = template_obj.mm_plugin._regularize_videos(
|
||||||
"video": template_obj.mm_plugin._regularize_videos(
|
video,
|
||||||
video,
|
image_max_pixels=image_max_pixels,
|
||||||
image_max_pixels=image_max_pixels,
|
image_min_pixels=image_min_pixels,
|
||||||
image_min_pixels=image_min_pixels,
|
video_fps=video_fps,
|
||||||
video_fps=video_fps,
|
video_maxlen=video_maxlen,
|
||||||
video_maxlen=video_maxlen,
|
)["videos"]
|
||||||
)["videos"]
|
|
||||||
}
|
|
||||||
if need_video_kwargs:
|
if need_video_kwargs:
|
||||||
container = av.open(video[0], "r")
|
container = av.open(video[0], "r")
|
||||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
@@ -192,18 +191,17 @@ def vllm_infer(
|
|||||||
video_backend="opencv",
|
video_backend="opencv",
|
||||||
)
|
)
|
||||||
multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
|
multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
|
||||||
elif batch["audios"][j] is not None:
|
|
||||||
|
if batch["audios"][j] is not None:
|
||||||
audio = batch["audios"][j]
|
audio = batch["audios"][j]
|
||||||
audio_data = template_obj.mm_plugin._regularize_audios(
|
audio_data = template_obj.mm_plugin._regularize_audios(
|
||||||
audio,
|
audio,
|
||||||
sampling_rate=16000,
|
sampling_rate=16000,
|
||||||
)
|
)
|
||||||
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
|
multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
|
||||||
else:
|
|
||||||
multi_modal_data = None
|
|
||||||
|
|
||||||
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
|
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data or None}
|
||||||
if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
|
if video_metadata_kwargs is not None:
|
||||||
vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
|
vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
|
||||||
|
|
||||||
vllm_inputs.append(vllm_input_data)
|
vllm_inputs.append(vllm_input_data)
|
||||||
|
|||||||
@@ -88,7 +88,10 @@ def _process_request(
|
|||||||
|
|
||||||
if request.messages[0].role == Role.SYSTEM:
|
if request.messages[0].role == Role.SYSTEM:
|
||||||
content = request.messages.pop(0).content
|
content = request.messages.pop(0).content
|
||||||
system = content[0].text if isinstance(content, list) else content
|
if isinstance(content, list):
|
||||||
|
system = content[0].text if content else ""
|
||||||
|
else:
|
||||||
|
system = content
|
||||||
else:
|
else:
|
||||||
system = None
|
system = None
|
||||||
|
|
||||||
|
|||||||
@@ -180,35 +180,32 @@ class VllmEngine(BaseEngine):
|
|||||||
else self.generating_args["skip_special_tokens"],
|
else self.generating_args["skip_special_tokens"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
multi_modal_data = {}
|
||||||
if images is not None: # add image features
|
if images is not None: # add image features
|
||||||
multi_modal_data = {
|
multi_modal_data["image"] = self.template.mm_plugin._regularize_images(
|
||||||
"image": self.template.mm_plugin._regularize_images(
|
images,
|
||||||
images,
|
image_max_pixels=self.model_args.image_max_pixels,
|
||||||
image_max_pixels=self.model_args.image_max_pixels,
|
image_min_pixels=self.model_args.image_min_pixels,
|
||||||
image_min_pixels=self.model_args.image_min_pixels,
|
)["images"]
|
||||||
)["images"]
|
|
||||||
}
|
if videos is not None:
|
||||||
elif videos is not None:
|
multi_modal_data["video"] = self.template.mm_plugin._regularize_videos(
|
||||||
multi_modal_data = {
|
videos,
|
||||||
"video": self.template.mm_plugin._regularize_videos(
|
image_max_pixels=self.model_args.video_max_pixels,
|
||||||
videos,
|
image_min_pixels=self.model_args.video_min_pixels,
|
||||||
image_max_pixels=self.model_args.video_max_pixels,
|
video_fps=self.model_args.video_fps,
|
||||||
image_min_pixels=self.model_args.video_min_pixels,
|
video_maxlen=self.model_args.video_maxlen,
|
||||||
video_fps=self.model_args.video_fps,
|
)["videos"]
|
||||||
video_maxlen=self.model_args.video_maxlen,
|
|
||||||
)["videos"]
|
if audios is not None:
|
||||||
}
|
|
||||||
elif audios is not None:
|
|
||||||
audio_data = self.template.mm_plugin._regularize_audios(
|
audio_data = self.template.mm_plugin._regularize_audios(
|
||||||
audios,
|
audios,
|
||||||
sampling_rate=self.model_args.audio_sampling_rate,
|
sampling_rate=self.model_args.audio_sampling_rate,
|
||||||
)
|
)
|
||||||
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
|
multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
|
||||||
else:
|
|
||||||
multi_modal_data = None
|
|
||||||
|
|
||||||
result_generator = self.model.generate(
|
result_generator = self.model.generate(
|
||||||
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data or None},
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=self.lora_request,
|
lora_request=self.lora_request,
|
||||||
|
|||||||
@@ -15,6 +15,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import inspect
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||||
|
|
||||||
@@ -24,7 +26,7 @@ import torch.nn.functional as F
|
|||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
from transformers import DataCollatorForSeq2Seq
|
from transformers import DataCollatorForSeq2Seq
|
||||||
|
|
||||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
|
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, MROPE_MODELS
|
||||||
from ..extras.packages import is_pillow_available
|
from ..extras.packages import is_pillow_available
|
||||||
|
|
||||||
|
|
||||||
@@ -38,6 +40,56 @@ if TYPE_CHECKING:
|
|||||||
from .template import Template
|
from .template import Template
|
||||||
|
|
||||||
|
|
||||||
|
def _slice_mm_inputs_for_sample(
|
||||||
|
mm_inputs: dict[str, Any],
|
||||||
|
batch_imglens: list[int],
|
||||||
|
batch_vidlens: list[int],
|
||||||
|
batch_idx: int,
|
||||||
|
images_per_subseq: Optional[list[int]] = None,
|
||||||
|
videos_per_subseq: Optional[list[int]] = None,
|
||||||
|
subseq_idx: Optional[int] = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
r"""Slice mm_inputs for one batch sample, optionally for a single sub-sequence when packing.
|
||||||
|
|
||||||
|
image_grid_thw / video_grid_thw have shape [num_items, 3]. Indices for sample batch_idx
|
||||||
|
are batch_imglens[batch_idx] images and batch_vidlens[batch_idx] videos. When subseq_idx
|
||||||
|
is given, further restrict to that sub-seq's counts via packed_*_counts.
|
||||||
|
has_dummy_image=True means only batch[0] will be concated with fake image and no multimodal data.
|
||||||
|
"""
|
||||||
|
image_start_idx = sum(batch_imglens[:batch_idx])
|
||||||
|
image_end_idx = sum(batch_imglens[: batch_idx + 1])
|
||||||
|
video_start_idx = sum(batch_vidlens[:batch_idx])
|
||||||
|
video_end_idx = sum(batch_vidlens[: batch_idx + 1])
|
||||||
|
|
||||||
|
if subseq_idx is not None and images_per_subseq is not None:
|
||||||
|
image_start_idx += sum(images_per_subseq[:subseq_idx])
|
||||||
|
image_end_idx = image_start_idx + images_per_subseq[subseq_idx]
|
||||||
|
|
||||||
|
if subseq_idx is not None and videos_per_subseq is not None:
|
||||||
|
video_start_idx += sum(videos_per_subseq[:subseq_idx])
|
||||||
|
video_end_idx = video_start_idx + videos_per_subseq[subseq_idx]
|
||||||
|
|
||||||
|
sliced_mm_inputs: dict[str, Any] = {}
|
||||||
|
key_to_slice_meta = {
|
||||||
|
"image_grid_thw": (image_start_idx, image_end_idx, True),
|
||||||
|
"video_grid_thw": (video_start_idx, video_end_idx, True),
|
||||||
|
"second_per_grid_ts": (video_start_idx, video_end_idx, False), # qwen2.5vl
|
||||||
|
"video_second_per_grid": (video_start_idx, video_end_idx, False), # qwen omni
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, (start_idx, end_idx, assign_none_when_empty) in key_to_slice_meta.items():
|
||||||
|
if key not in mm_inputs:
|
||||||
|
continue
|
||||||
|
|
||||||
|
mm_value = mm_inputs[key]
|
||||||
|
if mm_value is not None and end_idx > start_idx:
|
||||||
|
sliced_mm_inputs[key] = mm_value[start_idx:end_idx]
|
||||||
|
elif assign_none_when_empty:
|
||||||
|
sliced_mm_inputs[key] = None
|
||||||
|
|
||||||
|
return sliced_mm_inputs
|
||||||
|
|
||||||
|
|
||||||
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
|
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
|
||||||
r"""Expand 2d attention mask to 4d attention mask.
|
r"""Expand 2d attention mask to 4d attention mask.
|
||||||
|
|
||||||
@@ -105,9 +157,154 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
else:
|
else:
|
||||||
self.get_rope_func = None
|
self.get_rope_func = None
|
||||||
|
|
||||||
|
def _compute_rope_position_ids(
|
||||||
|
self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
r"""Compute position_ids and rope_deltas via get_rope_func for VLMs."""
|
||||||
|
rope_index_kwargs = {
|
||||||
|
"input_ids": features["input_ids"],
|
||||||
|
"image_grid_thw": mm_inputs.get("image_grid_thw"),
|
||||||
|
"video_grid_thw": mm_inputs.get("video_grid_thw"),
|
||||||
|
"attention_mask": (features["attention_mask"] >= 1).float(),
|
||||||
|
}
|
||||||
|
if features["attention_mask"].sum() == 0:
|
||||||
|
features["position_ids"] = torch.zeros((3, *features["input_ids"].shape))
|
||||||
|
features["rope_deltas"] = torch.zeros(features["input_ids"].shape[0])
|
||||||
|
return
|
||||||
|
|
||||||
|
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
|
||||||
|
image_token_id = getattr(self.model.config, "image_token_id", None)
|
||||||
|
video_token_id = getattr(self.model.config, "video_token_id", None)
|
||||||
|
if image_token_id is not None or video_token_id is not None:
|
||||||
|
mm_token_type_ids = torch.zeros_like(features["input_ids"])
|
||||||
|
if image_token_id is not None:
|
||||||
|
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
|
||||||
|
if video_token_id is not None:
|
||||||
|
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
|
||||||
|
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
|
||||||
|
|
||||||
|
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
|
||||||
|
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
||||||
|
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
|
||||||
|
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
|
||||||
|
|
||||||
|
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
|
||||||
|
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
|
||||||
|
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
|
||||||
|
if feature_attention_mask is not None: # FIXME: need to get video image lengths
|
||||||
|
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||||
|
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
|
||||||
|
|
||||||
|
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
|
||||||
|
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
|
||||||
|
dim=-1
|
||||||
|
).unsqueeze(-1)
|
||||||
|
else: # for qwen vl
|
||||||
|
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
|
||||||
|
|
||||||
|
def _compute_rope_position_ids_with_packing(
|
||||||
|
self,
|
||||||
|
features: dict[str, "torch.Tensor"],
|
||||||
|
mm_inputs: dict[str, Any],
|
||||||
|
packing_params_list: list[dict[str, Any] | None],
|
||||||
|
batch_imglens: list[int],
|
||||||
|
batch_vidlens: list[int],
|
||||||
|
batch_audlens: list[int],
|
||||||
|
has_dummy_image: bool,
|
||||||
|
) -> None:
|
||||||
|
r"""Compute position_ids and rope_deltas per sample (or per sub-sequence when packed), then merge and validate."""
|
||||||
|
bsz = features["input_ids"].size(0)
|
||||||
|
seq_len = features["input_ids"].size(1)
|
||||||
|
all_position_ids: list[torch.Tensor] = []
|
||||||
|
all_rope_deltas: list[torch.Tensor] = []
|
||||||
|
|
||||||
|
if has_dummy_image:
|
||||||
|
# for [0, seq_len] = [0, unpadded_length + right_padding_length + fake_input_ids_len + collator_padding_length]
|
||||||
|
# FIXME: maybe right_padding_length is large, with improper max_cutoff_len
|
||||||
|
unpadded_length = int(features["attention_mask"][0].bool().sum().item())
|
||||||
|
right_padding_length = int((packing_params_list[0] or {}).get("right_padding_length") or 0)
|
||||||
|
fake_input_padding_length = max(0, seq_len - unpadded_length - right_padding_length)
|
||||||
|
dummy_image_right_padding_mrope = torch.zeros((3, bsz, fake_input_padding_length))
|
||||||
|
dummy_image_right_padding_attention_mask = torch.zeros((bsz, fake_input_padding_length))
|
||||||
|
assert self.tokenizer.padding_side == "right", "padding_side should be right when fake image is injected"
|
||||||
|
dummy_mm_inputs = copy.deepcopy(mm_inputs)
|
||||||
|
|
||||||
|
for sample_idx in range(bsz):
|
||||||
|
sample_packing = (packing_params_list[sample_idx] or {}) if sample_idx < len(packing_params_list) else {}
|
||||||
|
sequence_boundaries = sample_packing.get("sequence_boundaries")
|
||||||
|
num_sub_seqs = (len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1
|
||||||
|
image_subseq_ids = sample_packing.get("image_subseq_ids") or []
|
||||||
|
video_subseq_ids = sample_packing.get("video_subseq_ids") or []
|
||||||
|
images_per_subseq = (
|
||||||
|
[image_subseq_ids.count(i) for i in range(num_sub_seqs)] if image_subseq_ids and num_sub_seqs > 1 else None
|
||||||
|
)
|
||||||
|
videos_per_subseq = (
|
||||||
|
[video_subseq_ids.count(i) for i in range(num_sub_seqs)] if video_subseq_ids and num_sub_seqs > 1 else None
|
||||||
|
)
|
||||||
|
if has_dummy_image:
|
||||||
|
mm_inputs = {}
|
||||||
|
|
||||||
|
if num_sub_seqs <= 1:
|
||||||
|
sample_features = {
|
||||||
|
"input_ids": features["input_ids"],
|
||||||
|
"attention_mask": features["attention_mask"][sample_idx : sample_idx + 1],
|
||||||
|
}
|
||||||
|
mm_inputs_for_sample = _slice_mm_inputs_for_sample(
|
||||||
|
mm_inputs, batch_imglens, batch_vidlens, sample_idx=sample_idx
|
||||||
|
)
|
||||||
|
self._compute_rope_position_ids(sample_features, mm_inputs_for_sample)
|
||||||
|
all_position_ids.append(sample_features["position_ids"])
|
||||||
|
all_rope_deltas.append(sample_features["rope_deltas"])
|
||||||
|
else:
|
||||||
|
# when we do packing, don't need rope_deltas when training.
|
||||||
|
sample_position_ids: list[torch.Tensor] = []
|
||||||
|
for subseq_idx in range(num_sub_seqs):
|
||||||
|
subseq_start = sequence_boundaries[subseq_idx]
|
||||||
|
subseq_end = sequence_boundaries[subseq_idx + 1]
|
||||||
|
subseq_features = {
|
||||||
|
"input_ids": features["input_ids"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
|
||||||
|
"attention_mask": features["attention_mask"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
|
||||||
|
}
|
||||||
|
mm_inputs_for_subseq = _slice_mm_inputs_for_sample(
|
||||||
|
mm_inputs,
|
||||||
|
batch_imglens,
|
||||||
|
batch_vidlens,
|
||||||
|
sample_idx,
|
||||||
|
images_per_subseq,
|
||||||
|
videos_per_subseq,
|
||||||
|
subseq_idx
|
||||||
|
)
|
||||||
|
self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq)
|
||||||
|
sample_position_ids.append(subseq_features["position_ids"])
|
||||||
|
all_position_ids.append(torch.cat(sample_position_ids, dim=-1))
|
||||||
|
|
||||||
|
batch_dim_for_position_ids = 1 if all_position_ids[0].dim() == 3 else 0
|
||||||
|
|
||||||
|
features["position_ids"] = torch.cat(all_position_ids, dim=batch_dim_for_position_ids)
|
||||||
|
if has_dummy_image:
|
||||||
|
mm_inputs = dummy_mm_inputs
|
||||||
|
|
||||||
|
expected_position_ids_shape = (bsz, seq_len) if all_position_ids[0].dim() == 2 else (
|
||||||
|
all_position_ids[0].size(0),
|
||||||
|
bsz,
|
||||||
|
seq_len,
|
||||||
|
)
|
||||||
|
# Check if position_ids shape matches expected shape.
|
||||||
|
# for further usage, we should padding to the right when some padding token on the right.
|
||||||
|
if has_dummy_image:
|
||||||
|
features["position_ids"] = torch.cat([features["position_ids"], dummy_image_right_padding_mrope], dim=-1)
|
||||||
|
features["attention_mask"] = torch.cat([features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1)
|
||||||
|
|
||||||
|
if features["position_ids"].shape != expected_position_ids_shape:
|
||||||
|
raise ValueError(
|
||||||
|
"Merged position_ids shape mismatch: "
|
||||||
|
f"got {features['position_ids'].shape}, expected {expected_position_ids_shape}."
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||||
batch_images, batch_videos, batch_audios = [], [], []
|
batch_images, batch_videos, batch_audios = [], [], []
|
||||||
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
|
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
|
||||||
|
packing_params_list: list[dict[str, Any] | None] = []
|
||||||
for feature in features:
|
for feature in features:
|
||||||
images = feature.pop("images", None) or []
|
images = feature.pop("images", None) or []
|
||||||
videos = feature.pop("videos", None) or []
|
videos = feature.pop("videos", None) or []
|
||||||
@@ -119,8 +316,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
batch_vidlens.append(len(videos))
|
batch_vidlens.append(len(videos))
|
||||||
batch_audlens.append(len(audios))
|
batch_audlens.append(len(audios))
|
||||||
batch_input_ids.append(feature["input_ids"])
|
batch_input_ids.append(feature["input_ids"])
|
||||||
|
packing_params_list.append(feature.pop("packing_params", None))
|
||||||
|
|
||||||
fake_input_ids = []
|
fake_input_ids = []
|
||||||
|
has_dummy_image = False
|
||||||
if (
|
if (
|
||||||
self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
|
self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
|
||||||
): # avoid process hanging in zero3/fsdp case
|
): # avoid process hanging in zero3/fsdp case
|
||||||
@@ -136,6 +335,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
fake_input_ids.extend(_fake_input_ids)
|
fake_input_ids.extend(_fake_input_ids)
|
||||||
batch_images = fake_images
|
batch_images = fake_images
|
||||||
batch_imglens[0] = 1
|
batch_imglens[0] = 1
|
||||||
|
has_dummy_image = True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
|
self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
|
||||||
@@ -182,45 +382,50 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
|
|
||||||
features: dict[str, torch.Tensor] = super().__call__(features)
|
features: dict[str, torch.Tensor] = super().__call__(features)
|
||||||
|
|
||||||
|
bsz, seq_len = features["input_ids"].shape[:2]
|
||||||
|
model_type = getattr(self.model.config, "model_type", None) if self.model is not None else None
|
||||||
|
is_omni = model_type in [
|
||||||
|
"qwen2_5_omni_thinker",
|
||||||
|
"qwen3_omni_moe_thinker",
|
||||||
|
]
|
||||||
|
|
||||||
if self.get_rope_func is not None:
|
if self.get_rope_func is not None:
|
||||||
rope_index_kwargs = {
|
# for mmrope situation, we should calculate position_ids and rope_deltas per sample.
|
||||||
"input_ids": features["input_ids"],
|
# When neat_packing is on, each sample has packing_params; None means no packing for that sample.
|
||||||
"image_grid_thw": mm_inputs.get("image_grid_thw"),
|
boundaries_list = [
|
||||||
"video_grid_thw": mm_inputs.get("video_grid_thw"),
|
p.get("sequence_boundaries") if p is not None else None for p in packing_params_list
|
||||||
"attention_mask": (features["attention_mask"] >= 1).float(),
|
]
|
||||||
}
|
has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
|
||||||
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
|
if has_dummy_image and has_packing:
|
||||||
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
# FIXME: too tricky, need to be refactored
|
||||||
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
|
features["has_dummy_image"] = True
|
||||||
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
|
|
||||||
|
|
||||||
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
|
# When fake image/audio was injected, sequence_boundaries no longer match the tensor; use non-packing path.
|
||||||
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
|
if not has_packing:
|
||||||
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
|
self._compute_rope_position_ids(features, mm_inputs)
|
||||||
if feature_attention_mask is not None: # FIXME: need to get video image lengths
|
else:
|
||||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
if is_omni:
|
||||||
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
|
raise RuntimeError("Omni models are not supported for packed sequences for now.")
|
||||||
|
|
||||||
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
|
self._compute_rope_position_ids_with_packing(
|
||||||
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
|
features,
|
||||||
dim=-1
|
mm_inputs,
|
||||||
).unsqueeze(-1)
|
packing_params_list,
|
||||||
else: # for qwen vl
|
batch_imglens,
|
||||||
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
|
batch_vidlens,
|
||||||
|
batch_audlens,
|
||||||
|
has_dummy_image,
|
||||||
|
)
|
||||||
|
|
||||||
|
# For transformers compatibility, after https://github.com/huggingface/transformers/issues/39400
|
||||||
|
if features["position_ids"].dim() == 3:
|
||||||
|
features["position_ids"] = torch.cat(
|
||||||
|
[features["position_ids"][0].unsqueeze(0), features["position_ids"]], dim=0
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.model is not None
|
self.model is not None
|
||||||
and getattr(self.model.config, "model_type", None)
|
and getattr(self.model.config, "model_type", None) in MROPE_MODELS
|
||||||
in [
|
|
||||||
"glm4v",
|
|
||||||
"Keye",
|
|
||||||
"qwen2_vl",
|
|
||||||
"qwen2_5_vl",
|
|
||||||
"qwen2_5_omni_thinker",
|
|
||||||
"qwen3_omni_moe_thinker",
|
|
||||||
"qwen3_vl",
|
|
||||||
"qwen3_vl_moe",
|
|
||||||
]
|
|
||||||
and ("position_ids" not in features or features["position_ids"].dim() != 3)
|
and ("position_ids" not in features or features["position_ids"].dim() != 3)
|
||||||
):
|
):
|
||||||
raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")
|
raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")
|
||||||
@@ -248,12 +453,51 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
|
|||||||
block_diag_attn: bool = False
|
block_diag_attn: bool = False
|
||||||
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||||
compute_dtype: "torch.dtype" = torch.float32
|
compute_dtype: "torch.dtype" = torch.float32
|
||||||
|
neat_packing: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
if self.neat_packing and self.attn_implementation == "flash_attention_2":
|
||||||
|
if self.model is not None and getattr(self.model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe", "gpt_oss"]:
|
||||||
|
raise ValueError("Neat packing is not supported for qwen3_5, qwen3_5_moe, gpt_oss models for now.")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unpad_packed_features(features: dict[str, Any]) -> None:
|
||||||
|
r"""Trim padded positions for packed FA2 batches."""
|
||||||
|
attention_mask = features.get("attention_mask")
|
||||||
|
if not torch.is_tensor(attention_mask) or attention_mask.dim() != 2 or attention_mask.size(0) != 1:
|
||||||
|
return
|
||||||
|
|
||||||
|
seq_len = attention_mask.size(1)
|
||||||
|
non_padding_indices = torch.nonzero(attention_mask[0] != 0, as_tuple=False).flatten()
|
||||||
|
if non_padding_indices.numel() == seq_len:
|
||||||
|
return
|
||||||
|
|
||||||
|
keys_on_seq_dim_1 = {"input_ids", "labels", "attention_mask", "token_type_ids"}
|
||||||
|
for key, value in list(features.items()):
|
||||||
|
if not torch.is_tensor(value):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if key == "position_ids" and value.size(-1) == seq_len:
|
||||||
|
features[key] = value.index_select(-1, non_padding_indices)
|
||||||
|
elif key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len:
|
||||||
|
features[key] = value.index_select(1, non_padding_indices)
|
||||||
|
elif key in keys_on_seq_dim_1 and value.dim() == 2 and value.size(0) == 1 and value.size(1) == seq_len:
|
||||||
|
features[key] = value.index_select(1, non_padding_indices)
|
||||||
|
|
||||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||||
features = super().__call__(features)
|
features = super().__call__(features)
|
||||||
|
has_dummy_image = features.pop("has_dummy_image", False)
|
||||||
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||||
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||||
|
|
||||||
|
if self.neat_packing and self.attn_implementation == "flash_attention_2": # FIXME compatibility fa3/fa4
|
||||||
|
assert features["input_ids"].shape[0] == 1, "bsz should be 1 for neat packing"
|
||||||
|
if not has_dummy_image:
|
||||||
|
self._unpad_packed_features(features)
|
||||||
|
|
||||||
|
features["attention_mask"] = None # let transformers handle causal packed mask.
|
||||||
|
|
||||||
for key, value in features.items(): # cast data dtype for paligemma
|
for key, value in features.items(): # cast data dtype for paligemma
|
||||||
if torch.is_tensor(value) and torch.is_floating_point(value):
|
if torch.is_tensor(value) and torch.is_floating_point(value):
|
||||||
features[key] = value.to(self.compute_dtype)
|
features[key] = value.to(self.compute_dtype)
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ def read_cloud_json(cloud_path: str) -> list[Any]:
|
|||||||
|
|
||||||
# filter out non-JSON files
|
# filter out non-JSON files
|
||||||
files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
|
files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
|
||||||
files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
|
files = list(filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files))
|
||||||
if not files:
|
if not files:
|
||||||
raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
|
raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
|
||||||
|
|
||||||
|
|||||||
@@ -27,11 +27,12 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, Type
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
|
from transformers.image_utils import get_image_size, is_valid_image, make_flat_list_of_images, to_numpy_array
|
||||||
from transformers.models.mllama.processing_mllama import (
|
from transformers.models.mllama.processing_mllama import (
|
||||||
convert_sparse_cross_attention_mask_to_dense,
|
convert_sparse_cross_attention_mask_to_dense,
|
||||||
get_cross_attention_token_mask,
|
get_cross_attention_token_mask,
|
||||||
)
|
)
|
||||||
|
from transformers.video_utils import make_batched_videos
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||||
@@ -47,13 +48,6 @@ if is_pyav_available():
|
|||||||
import av
|
import av
|
||||||
|
|
||||||
|
|
||||||
if is_transformers_version_greater_than("4.52.0"):
|
|
||||||
from transformers.image_utils import make_flat_list_of_images
|
|
||||||
from transformers.video_utils import make_batched_videos
|
|
||||||
else:
|
|
||||||
from transformers.image_utils import make_batched_videos, make_flat_list_of_images
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from av.stream import Stream
|
from av.stream import Stream
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
@@ -161,7 +155,9 @@ class MMPluginMixin:
|
|||||||
video_processor: BaseImageProcessor = getattr(
|
video_processor: BaseImageProcessor = getattr(
|
||||||
processor, "video_processor", getattr(processor, "image_processor", None)
|
processor, "video_processor", getattr(processor, "image_processor", None)
|
||||||
)
|
)
|
||||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||||
|
processor, "audio_processor", None
|
||||||
|
)
|
||||||
if len(images) != 0 and self.image_token is None:
|
if len(images) != 0 and self.image_token is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"This model does not support image input. Please check whether the correct `template` is used."
|
"This model does not support image input. Please check whether the correct `template` is used."
|
||||||
@@ -390,7 +386,9 @@ class MMPluginMixin:
|
|||||||
mm_inputs.update(video_processor(videos, return_tensors="pt"))
|
mm_inputs.update(video_processor(videos, return_tensors="pt"))
|
||||||
|
|
||||||
if len(audios) != 0:
|
if len(audios) != 0:
|
||||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||||
|
processor, "audio_processor", None
|
||||||
|
)
|
||||||
audios = self._regularize_audios(
|
audios = self._regularize_audios(
|
||||||
audios,
|
audios,
|
||||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||||
@@ -1054,7 +1052,9 @@ class MiniCPMVPlugin(BasePlugin):
|
|||||||
chunk_input=True,
|
chunk_input=True,
|
||||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||||
)
|
)
|
||||||
audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
|
audio_feature_lens = [
|
||||||
|
x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens
|
||||||
|
]
|
||||||
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
||||||
if kwargs.get("ret_phs", False):
|
if kwargs.get("ret_phs", False):
|
||||||
mm_inputs.update({"audio_phs": audio_phs})
|
mm_inputs.update({"audio_phs": audio_phs})
|
||||||
@@ -1094,7 +1094,7 @@ class MiniCPMVPlugin(BasePlugin):
|
|||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
|
video_seqlen = len(mm_inputs["image_sizes"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||||
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
||||||
num_video_tokens += 1
|
num_video_tokens += 1
|
||||||
|
|
||||||
@@ -1876,7 +1876,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
) -> dict[str, "torch.Tensor"]:
|
) -> dict[str, "torch.Tensor"]:
|
||||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||||
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
|
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
|
||||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||||
|
processor, "audio_processor", None
|
||||||
|
)
|
||||||
mm_inputs = {}
|
mm_inputs = {}
|
||||||
if len(images) != 0:
|
if len(images) != 0:
|
||||||
images = self._regularize_images(
|
images = self._regularize_images(
|
||||||
@@ -1981,6 +1983,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
|
f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
position_id_per_seconds: int = getattr(processor, "position_id_per_seconds", 25)
|
||||||
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
||||||
video_t_index = (
|
video_t_index = (
|
||||||
torch.arange(video_grid_thw[num_video_tokens][0])
|
torch.arange(video_grid_thw[num_video_tokens][0])
|
||||||
@@ -1992,9 +1995,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
|||||||
)
|
)
|
||||||
.flatten()
|
.flatten()
|
||||||
* mm_inputs["video_second_per_grid"][num_video_tokens]
|
* mm_inputs["video_second_per_grid"][num_video_tokens]
|
||||||
* 25 # FIXME hardcode of position_id_per_seconds=25
|
* position_id_per_seconds
|
||||||
).long()
|
).long()
|
||||||
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
|
t_ntoken_per_chunk = position_id_per_seconds * 2
|
||||||
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
|
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
|
||||||
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
|
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
|
||||||
placeholder_string = ""
|
placeholder_string = ""
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from ...extras import logging
|
from ...extras import logging
|
||||||
@@ -27,6 +27,23 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
MAX_SU_SEQ_IDX = 2**32 # maximum sub-sequence index
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PackingParams:
|
||||||
|
r"""Metadata for a packed sequence: sub-sequence boundaries and multimodal data indices.
|
||||||
|
|
||||||
|
- sequence_boundaries: cumulative token positions, e.g. [0, 100, 250, 512] means 3 sub-seqs
|
||||||
|
with token ranges [0,100), [100,250), [250,512). Length = num_sub_seqs + 1.
|
||||||
|
- image_subseq_ids / video_subseq_ids / audio_subseq_ids: for each mm item, the 0-based
|
||||||
|
sub-sequence index it belongs to. Length = total number of that mm type in the packed sample.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sequence_boundaries: list[int]
|
||||||
|
image_subseq_ids: list[int]
|
||||||
|
video_subseq_ids: list[int]
|
||||||
|
audio_subseq_ids: list[int]
|
||||||
|
right_padding_length: int
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SupervisedDatasetProcessor(DatasetProcessor):
|
class SupervisedDatasetProcessor(DatasetProcessor):
|
||||||
@@ -162,10 +179,17 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
|||||||
valid_num += 1
|
valid_num += 1
|
||||||
|
|
||||||
model_inputs = defaultdict(list)
|
model_inputs = defaultdict(list)
|
||||||
|
requires_packing_params = self.data_args.neat_packing
|
||||||
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
|
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
|
||||||
for knapsack in knapsacks:
|
for knapsack in knapsacks:
|
||||||
packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
|
packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
|
||||||
packed_images, packed_videos, packed_audios = [], [], []
|
packed_images, packed_videos, packed_audios = [], [], []
|
||||||
|
if requires_packing_params:
|
||||||
|
sequence_boundaries = [0]
|
||||||
|
image_subseq_ids: list[int] = []
|
||||||
|
video_subseq_ids: list[int] = []
|
||||||
|
audio_subseq_ids: list[int] = []
|
||||||
|
|
||||||
for i, length in enumerate(knapsack):
|
for i, length in enumerate(knapsack):
|
||||||
index = length2indexes[length].pop()
|
index = length2indexes[length].pop()
|
||||||
packed_input_ids += batch_input_ids[index]
|
packed_input_ids += batch_input_ids[index]
|
||||||
@@ -174,6 +198,15 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
|||||||
packed_images += batch_images[index]
|
packed_images += batch_images[index]
|
||||||
packed_videos += batch_videos[index]
|
packed_videos += batch_videos[index]
|
||||||
packed_audios += batch_audios[index]
|
packed_audios += batch_audios[index]
|
||||||
|
if requires_packing_params:
|
||||||
|
n_img = len(batch_images[index])
|
||||||
|
n_vid = len(batch_videos[index])
|
||||||
|
n_aud = len(batch_audios[index])
|
||||||
|
sequence_boundaries.append(sequence_boundaries[-1] + len(batch_input_ids[index]))
|
||||||
|
image_subseq_ids.extend([i] * n_img)
|
||||||
|
video_subseq_ids.extend([i] * n_vid)
|
||||||
|
audio_subseq_ids.extend([i] * n_aud)
|
||||||
|
|
||||||
if self.data_args.neat_packing:
|
if self.data_args.neat_packing:
|
||||||
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
|
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
|
||||||
else:
|
else:
|
||||||
@@ -189,10 +222,23 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
|||||||
else:
|
else:
|
||||||
packed_attention_masks += [1] * pad_length # more efficient flash_attn
|
packed_attention_masks += [1] * pad_length # more efficient flash_attn
|
||||||
|
|
||||||
|
if requires_packing_params:
|
||||||
|
sequence_boundaries.append(sequence_boundaries[-1] + pad_length)
|
||||||
|
|
||||||
if len(packed_input_ids) != self.data_args.cutoff_len + 1:
|
if len(packed_input_ids) != self.data_args.cutoff_len + 1:
|
||||||
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
||||||
|
|
||||||
model_inputs["input_ids"].append(packed_input_ids)
|
model_inputs["input_ids"].append(packed_input_ids)
|
||||||
|
if requires_packing_params:
|
||||||
|
packing_params = PackingParams(
|
||||||
|
sequence_boundaries=sequence_boundaries,
|
||||||
|
image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX], # avoid dataset concat error
|
||||||
|
video_subseq_ids=video_subseq_ids or [MAX_SU_SEQ_IDX],
|
||||||
|
audio_subseq_ids=audio_subseq_ids or [MAX_SU_SEQ_IDX],
|
||||||
|
right_padding_length=pad_length,
|
||||||
|
)
|
||||||
|
model_inputs["packing_params"].append(asdict(packing_params))
|
||||||
|
|
||||||
model_inputs["attention_mask"].append(packed_attention_masks)
|
model_inputs["attention_mask"].append(packed_attention_masks)
|
||||||
model_inputs["position_ids"].append(packed_position_ids)
|
model_inputs["position_ids"].append(packed_position_ids)
|
||||||
model_inputs["labels"].append(packed_labels)
|
model_inputs["labels"].append(packed_labels)
|
||||||
|
|||||||
@@ -1061,6 +1061,22 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# copied from glm4 template
|
||||||
|
register_template(
|
||||||
|
name="glm_ocr",
|
||||||
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||||
|
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||||
|
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||||
|
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
|
||||||
|
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||||
|
format_tools=ToolFormatter(tool_format="glm4"),
|
||||||
|
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||||
|
stop_words=["<|user|>", "<|observation|>"],
|
||||||
|
efficient_eos=True,
|
||||||
|
mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# copied from glm4_moe template
|
# copied from glm4_moe template
|
||||||
register_template(
|
register_template(
|
||||||
name="glm4_7",
|
name="glm4_7",
|
||||||
@@ -1097,7 +1113,7 @@ register_template(
|
|||||||
register_template(
|
register_template(
|
||||||
name="gpt_oss",
|
name="gpt_oss",
|
||||||
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
|
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
|
||||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
|
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||||
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
|
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
|
||||||
default_system="You are ChatGPT, a large language model trained by OpenAI.",
|
default_system="You are ChatGPT, a large language model trained by OpenAI.",
|
||||||
thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
|
thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
|
||||||
@@ -2013,6 +2029,39 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_template(
|
||||||
|
name="qwen3_5",
|
||||||
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
|
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
|
||||||
|
format_observation=StringFormatter(
|
||||||
|
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||||
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="qwen3_5"),
|
||||||
|
stop_words=["<|im_end|>"],
|
||||||
|
replace_eos=True,
|
||||||
|
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
|
||||||
|
template_class=ReasoningTemplate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_template(
|
||||||
|
name="qwen3_5_nothink",
|
||||||
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||||
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
|
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
|
||||||
|
format_observation=StringFormatter(
|
||||||
|
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||||
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="qwen3_5"),
|
||||||
|
stop_words=["<|im_end|>"],
|
||||||
|
replace_eos=True,
|
||||||
|
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="sailor",
|
name="sailor",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
|
||||||
@@ -2202,3 +2251,24 @@ register_template(
|
|||||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||||
default_system="You are Zephyr, a helpful assistant.",
|
default_system="You are Zephyr, a helpful assistant.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# copied from glm4_7 template
|
||||||
|
register_template(
|
||||||
|
name="aeva",
|
||||||
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||||
|
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||||
|
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||||
|
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
|
||||||
|
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||||
|
format_tools=ToolFormatter(tool_format="glm4_moe"),
|
||||||
|
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||||
|
default_system=(
|
||||||
|
"You are an AI assistant named Aeva created by Zongzhi Lou. "
|
||||||
|
"Your answer should be friendly, unbiased, faithful, informative and detailed."
|
||||||
|
),
|
||||||
|
stop_words=["<|user|>", "<|observation|>"],
|
||||||
|
thought_words=("<think>", "</think>"),
|
||||||
|
efficient_eos=True,
|
||||||
|
template_class=Glm47ReasoningTemplate,
|
||||||
|
)
|
||||||
|
|||||||
@@ -85,6 +85,21 @@ QWEN_TOOL_PROMPT = (
|
|||||||
""""arguments": <args-json-object>}}\n</tool_call>"""
|
""""arguments": <args-json-object>}}\n</tool_call>"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
QWEN35_TOOL_PROMPT = (
|
||||||
|
"\n\n# Tools\n\nYou have access to the following functions:\n\n<tools>{tool_text}"
|
||||||
|
"\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n"
|
||||||
|
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n"
|
||||||
|
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n"
|
||||||
|
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n"
|
||||||
|
"- Function calls MUST follow the specified format: "
|
||||||
|
"an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n"
|
||||||
|
"- Required parameters MUST be specified\n"
|
||||||
|
"- You may provide optional reasoning for your function call in natural language "
|
||||||
|
"BEFORE the function call, but NOT after\n"
|
||||||
|
"- If there is no function call available, answer the question like normal with your current knowledge "
|
||||||
|
"and do not tell the user about function calls\n</IMPORTANT>"
|
||||||
|
)
|
||||||
|
|
||||||
SEED_TOOL_PROMPT = (
|
SEED_TOOL_PROMPT = (
|
||||||
"system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query."
|
"system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query."
|
||||||
"Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing "
|
"Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing "
|
||||||
@@ -453,6 +468,57 @@ class QwenToolUtils(ToolUtils):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen35ToolUtils(ToolUtils):
|
||||||
|
r"""Qwen 3.5 tool using template."""
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||||
|
tool_text = ""
|
||||||
|
for tool in tools:
|
||||||
|
tool = tool.get("function", tool) if tool.get("type") == "function" else tool
|
||||||
|
tool_text += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||||
|
|
||||||
|
return QWEN35_TOOL_PROMPT.format(tool_text=tool_text)
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||||
|
function_texts = []
|
||||||
|
for func in functions:
|
||||||
|
name, arguments = func.name, json.loads(func.arguments)
|
||||||
|
prompt = f"<tool_call>\n<function={name}>"
|
||||||
|
for key, value in arguments.items():
|
||||||
|
prompt += f"\n<parameter={key}>"
|
||||||
|
if not isinstance(value, str):
|
||||||
|
value = json.dumps(value, ensure_ascii=False)
|
||||||
|
prompt += f"\n{value}\n</parameter>"
|
||||||
|
prompt += "\n</function>\n</tool_call>"
|
||||||
|
function_texts.append(prompt)
|
||||||
|
|
||||||
|
return "\n".join(function_texts)
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||||
|
results = []
|
||||||
|
regex = re.compile(r"<tool_call>\s*<function=\s*([^\s<>]+)\s*(.*?)\s*</function>\s*</tool_call>", re.DOTALL)
|
||||||
|
for func_name, params_block in re.findall(regex, content):
|
||||||
|
args_dict = {}
|
||||||
|
param_pattern = re.compile(r"<parameter=(.*?)>(.*?)</parameter>", re.DOTALL)
|
||||||
|
for key, raw_value in re.findall(param_pattern, params_block.strip()):
|
||||||
|
value = raw_value.strip()
|
||||||
|
try:
|
||||||
|
parsed_value = json.loads(value)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
parsed_value = raw_value.strip()
|
||||||
|
args_dict[key] = parsed_value
|
||||||
|
|
||||||
|
results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))
|
||||||
|
|
||||||
|
return results if results else content
|
||||||
|
|
||||||
|
|
||||||
class GLM4MOEToolUtils(QwenToolUtils):
|
class GLM4MOEToolUtils(QwenToolUtils):
|
||||||
r"""GLM-4-MOE tool using template."""
|
r"""GLM-4-MOE tool using template."""
|
||||||
|
|
||||||
@@ -662,6 +728,7 @@ TOOLS = {
|
|||||||
"minimax2": MiniMaxM2ToolUtils(),
|
"minimax2": MiniMaxM2ToolUtils(),
|
||||||
"mistral": MistralToolUtils(),
|
"mistral": MistralToolUtils(),
|
||||||
"qwen": QwenToolUtils(),
|
"qwen": QwenToolUtils(),
|
||||||
|
"qwen3_5": Qwen35ToolUtils(),
|
||||||
"glm4_moe": GLM4MOEToolUtils(),
|
"glm4_moe": GLM4MOEToolUtils(),
|
||||||
"seed_oss": SeedToolUtils(),
|
"seed_oss": SeedToolUtils(),
|
||||||
"ling": LingToolUtils(),
|
"ling": LingToolUtils(),
|
||||||
|
|||||||
@@ -65,15 +65,32 @@ MCA_SUPPORTED_MODELS = {
|
|||||||
"qwen2_vl",
|
"qwen2_vl",
|
||||||
"qwen2_5_vl",
|
"qwen2_5_vl",
|
||||||
"qwen3_vl",
|
"qwen3_vl",
|
||||||
|
"qwen3_vl_moe",
|
||||||
"qwen3",
|
"qwen3",
|
||||||
"qwen3_moe",
|
"qwen3_moe",
|
||||||
"qwen3_next",
|
"qwen3_next",
|
||||||
|
"qwen3_5",
|
||||||
|
"qwen3_5_moe",
|
||||||
}
|
}
|
||||||
|
|
||||||
METHODS = ["full", "freeze", "lora", "oft"]
|
METHODS = ["full", "freeze", "lora", "oft"]
|
||||||
|
|
||||||
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
||||||
|
|
||||||
|
MROPE_MODELS = {
|
||||||
|
"glm4v",
|
||||||
|
"glm_ocr",
|
||||||
|
"Keye",
|
||||||
|
"qwen2_vl",
|
||||||
|
"qwen2_5_vl",
|
||||||
|
"qwen2_5_omni_thinker",
|
||||||
|
"qwen3_omni_moe_thinker",
|
||||||
|
"qwen3_vl",
|
||||||
|
"qwen3_vl_moe",
|
||||||
|
"qwen3_5",
|
||||||
|
"qwen3_5_moe",
|
||||||
|
}
|
||||||
|
|
||||||
MULTIMODAL_SUPPORTED_MODELS = set()
|
MULTIMODAL_SUPPORTED_MODELS = set()
|
||||||
|
|
||||||
PEFT_METHODS = {"lora", "oft"}
|
PEFT_METHODS = {"lora", "oft"}
|
||||||
@@ -950,6 +967,18 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"GLM-OCR": {
|
||||||
|
DownloadSource.DEFAULT: "zai-org/GLM-OCR",
|
||||||
|
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-OCR",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="glm_ocr",
|
||||||
|
multimodal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"GLM-Z1-0414-9B-Chat": {
|
"GLM-Z1-0414-9B-Chat": {
|
||||||
@@ -2797,6 +2826,66 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Qwen3.5-0.8B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B-Base",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B-Base",
|
||||||
|
},
|
||||||
|
"Qwen3.5-2B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B-Base",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B-Base",
|
||||||
|
},
|
||||||
|
"Qwen3.5-4B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B-Base",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B-Base",
|
||||||
|
},
|
||||||
|
"Qwen3.5-9B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B-Base",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B-Base",
|
||||||
|
},
|
||||||
|
"Qwen3.5-35B-A3B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B-Base",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B-Base",
|
||||||
|
},
|
||||||
|
"Qwen3.5-0.8B-Thinking": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B",
|
||||||
|
},
|
||||||
|
"Qwen3.5-2B-Thinking": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B",
|
||||||
|
},
|
||||||
|
"Qwen3.5-4B-Thinking": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B",
|
||||||
|
},
|
||||||
|
"Qwen3.5-9B-Thinking": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B",
|
||||||
|
},
|
||||||
|
"Qwen3.5-27B-Thinking": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen3.5-27B",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-27B",
|
||||||
|
},
|
||||||
|
"Qwen3.5-35B-A3B-Thinking": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B",
|
||||||
|
},
|
||||||
|
"Qwen3.5-122B-A10B-Thinking": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen3.5-122B-A10B",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-122B-A10B",
|
||||||
|
},
|
||||||
|
"Qwen3.5-397B-A17B-Thinking": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen3.5-397B-A17B",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-397B-A17B",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="qwen3_5",
|
||||||
|
multimodal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Qwen2-Audio-7B": {
|
"Qwen2-Audio-7B": {
|
||||||
@@ -3438,3 +3527,35 @@ register_model_group(
|
|||||||
},
|
},
|
||||||
template="zephyr",
|
template="zephyr",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Aeva-Flash-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "louzongzhi/Aeva-Flash",
|
||||||
|
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Flash",
|
||||||
|
DownloadSource.OPENMIND: "louzongzhi/Aeva-Flash",
|
||||||
|
},
|
||||||
|
"Aeva-Air-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "louzongzhi/Aeva-Air",
|
||||||
|
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Air",
|
||||||
|
DownloadSource.OPENMIND: "louzongzhi/Aeva-Air",
|
||||||
|
},
|
||||||
|
"Aeva-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "louzongzhi/Aeva",
|
||||||
|
DownloadSource.MODELSCOPE: "louzongktsi/Aeva",
|
||||||
|
DownloadSource.OPENMIND: "louzongzhi/Aeva",
|
||||||
|
},
|
||||||
|
"Aeva-Pro-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "louzongzhi/Aeva-Pro",
|
||||||
|
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Pro",
|
||||||
|
DownloadSource.OPENMIND: "louzongzhi/Aeva-Pro",
|
||||||
|
},
|
||||||
|
"Aeva-Max-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "louzongzhi/Aeva-Max",
|
||||||
|
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Max",
|
||||||
|
DownloadSource.OPENMIND: "louzongzhi/Aeva-Max",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="aeva",
|
||||||
|
)
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
|||||||
|
|
||||||
def check_dependencies() -> None:
|
def check_dependencies() -> None:
|
||||||
r"""Check the version of the required packages."""
|
r"""Check the version of the required packages."""
|
||||||
check_version("transformers>=4.51.0,<=5.0.0")
|
check_version("transformers>=4.55.0,<=5.2.0")
|
||||||
check_version("datasets>=2.16.0,<=4.0.0")
|
check_version("datasets>=2.16.0,<=4.0.0")
|
||||||
check_version("accelerate>=1.3.0,<=1.11.0")
|
check_version("accelerate>=1.3.0,<=1.11.0")
|
||||||
check_version("peft>=0.18.0,<=0.18.1")
|
check_version("peft>=0.18.0,<=0.18.1")
|
||||||
|
|||||||
@@ -490,6 +490,14 @@ class FinetuningArguments(
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether to use the DFT loss."},
|
metadata={"help": "Whether to use the DFT loss."},
|
||||||
)
|
)
|
||||||
|
use_asft_loss: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether to use the ASFT loss."},
|
||||||
|
)
|
||||||
|
asft_alpha: float = field(
|
||||||
|
default=0.1,
|
||||||
|
metadata={"help": "The alpha parameter for ASFT loss to control the power of adaptive weight."},
|
||||||
|
)
|
||||||
use_eaft_loss: bool = field(
|
use_eaft_loss: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether to use the EAFT loss."},
|
metadata={"help": "Whether to use the EAFT loss."},
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
|
|||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.constants import CHECKPOINT_NAMES, EngineName
|
from ..extras.constants import CHECKPOINT_NAMES, EngineName
|
||||||
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
|
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
|
||||||
from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
|
from ..extras.packages import is_mcore_adapter_available
|
||||||
from .data_args import DataArguments
|
from .data_args import DataArguments
|
||||||
from .evaluation_args import EvaluationArguments
|
from .evaluation_args import EvaluationArguments
|
||||||
from .finetuning_args import FinetuningArguments
|
from .finetuning_args import FinetuningArguments
|
||||||
@@ -100,6 +100,52 @@ def _parse_args(
|
|||||||
return tuple(parsed_args)
|
return tuple(parsed_args)
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_trackio_args(training_args: "TrainingArguments") -> None:
|
||||||
|
"""Validates Trackio-specific arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
training_args: TrainingArguments instance (not a dictionary)
|
||||||
|
"""
|
||||||
|
report_to = training_args.report_to
|
||||||
|
if not report_to:
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(report_to, str):
|
||||||
|
report_to = [report_to]
|
||||||
|
|
||||||
|
if "trackio" not in report_to:
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- Enforce project (required by Trackio) ---
|
||||||
|
if not training_args.project:
|
||||||
|
raise ValueError("`--project` must be specified when using Trackio.")
|
||||||
|
|
||||||
|
# --- Validate trackio_space_id format ---
|
||||||
|
space_id = training_args.trackio_space_id
|
||||||
|
if space_id:
|
||||||
|
if space_id != "trackio" and "/" not in space_id:
|
||||||
|
logger.warning(
|
||||||
|
f"trackio_space_id '{space_id}' should typically be in format "
|
||||||
|
"'org/space' for Hugging Face Spaces deployment."
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Inform about default project usage ---
|
||||||
|
if training_args.project == "huggingface":
|
||||||
|
logger.info(
|
||||||
|
"Using default project name 'huggingface'. "
|
||||||
|
"Consider setting a custom project name with --project "
|
||||||
|
"for better organization."
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Validate hub repo privacy flag ---
|
||||||
|
if training_args.hub_private_repo:
|
||||||
|
logger.info("Repository will be created as private on Hugging Face Hub.")
|
||||||
|
|
||||||
|
# --- Recommend run_name for experiment clarity ---
|
||||||
|
if not training_args.run_name:
|
||||||
|
logger.warning("Consider setting --run_name for better experiment tracking clarity.")
|
||||||
|
|
||||||
|
|
||||||
def _set_transformers_logging() -> None:
|
def _set_transformers_logging() -> None:
|
||||||
if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
|
if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
|
||||||
transformers.utils.logging.set_verbosity_info()
|
transformers.utils.logging.set_verbosity_info()
|
||||||
@@ -278,8 +324,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
|||||||
if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
||||||
raise ValueError("Unsloth does not support lora reward model.")
|
raise ValueError("Unsloth does not support lora reward model.")
|
||||||
|
|
||||||
if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
|
if training_args.report_to and any(
|
||||||
raise ValueError("PPO only accepts wandb or tensorboard logger.")
|
logger not in ("wandb", "tensorboard", "trackio", "none") for logger in training_args.report_to
|
||||||
|
):
|
||||||
|
raise ValueError("PPO only accepts wandb, tensorboard, or trackio logger.")
|
||||||
|
|
||||||
if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
|
if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
|
||||||
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
|
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
|
||||||
@@ -346,12 +394,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
|||||||
if model_args.use_kt and is_deepspeed_zero3_enabled():
|
if model_args.use_kt and is_deepspeed_zero3_enabled():
|
||||||
raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")
|
raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")
|
||||||
|
|
||||||
if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
|
|
||||||
raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
|
|
||||||
|
|
||||||
_set_env_vars()
|
_set_env_vars()
|
||||||
_verify_model_args(model_args, data_args, finetuning_args)
|
_verify_model_args(model_args, data_args, finetuning_args)
|
||||||
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||||
|
_verify_trackio_args(training_args)
|
||||||
|
|
||||||
if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
|
if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
|
||||||
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
||||||
@@ -421,7 +467,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
|||||||
training_args.resume_from_checkpoint is None
|
training_args.resume_from_checkpoint is None
|
||||||
and training_args.do_train
|
and training_args.do_train
|
||||||
and os.path.isdir(training_args.output_dir)
|
and os.path.isdir(training_args.output_dir)
|
||||||
and not training_args.overwrite_output_dir
|
and not getattr(training_args, "overwrite_output_dir", False) # for mca training args and transformers >= 5.0
|
||||||
and can_resume_from_checkpoint
|
and can_resume_from_checkpoint
|
||||||
):
|
):
|
||||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||||
|
|||||||
@@ -77,6 +77,8 @@ def apply_liger_kernel(
|
|||||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
|
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
|
||||||
elif model_type == "qwen3_moe":
|
elif model_type == "qwen3_moe":
|
||||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
|
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
|
||||||
|
elif model_type == "qwen3_next":
|
||||||
|
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
|
||||||
elif model_type == "gpt_oss":
|
elif model_type == "gpt_oss":
|
||||||
try:
|
try:
|
||||||
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
|
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
|
||||||
|
|||||||
@@ -142,6 +142,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
|||||||
|
|
||||||
_set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
|
_set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
|
||||||
|
|
||||||
|
if model_type == "qwen3_next":
|
||||||
|
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
|
||||||
|
|
||||||
|
_set_z3_leaf_modules(model, [Qwen3NextSparseMoeBlock])
|
||||||
|
|
||||||
|
|
||||||
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||||
if not is_trainable or not model_args.moe_aux_loss_coef:
|
if not is_trainable or not model_args.moe_aux_loss_coef:
|
||||||
|
|||||||
@@ -37,7 +37,6 @@
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -45,10 +44,6 @@ import torch.nn.functional as F
|
|||||||
from ...extras import logging
|
from ...extras import logging
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ...hparams import ModelArguments
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -105,13 +100,3 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "tor
|
|||||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||||
return indices, cu_seqlens, max_seqlen_in_batch
|
return indices, cu_seqlens, max_seqlen_in_batch
|
||||||
|
|
||||||
|
|
||||||
def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
|
|
||||||
if not is_trainable or not model_args.block_diag_attn:
|
|
||||||
return
|
|
||||||
|
|
||||||
import transformers.modeling_flash_attention_utils
|
|
||||||
|
|
||||||
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
|
|
||||||
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
|
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import transformers.models
|
|||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
|
|
||||||
from ...extras import logging
|
from ...extras import logging
|
||||||
from ...extras.packages import is_transformers_version_greater_than
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -239,6 +238,15 @@ _register_composite_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_composite_model(
|
||||||
|
model_type="glm_ocr",
|
||||||
|
projector_key="visual.merger",
|
||||||
|
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
||||||
|
language_model_keys=["language_model", "lm_head"],
|
||||||
|
lora_conflict_keys=["patch_embed"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
_register_composite_model(
|
_register_composite_model(
|
||||||
model_type="internvl",
|
model_type="internvl",
|
||||||
)
|
)
|
||||||
@@ -335,9 +343,7 @@ _register_composite_model(
|
|||||||
model_type="qwen2_vl",
|
model_type="qwen2_vl",
|
||||||
projector_key="visual.merger",
|
projector_key="visual.merger",
|
||||||
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
||||||
language_model_keys=["language_model", "lm_head"]
|
language_model_keys=["language_model", "lm_head"],
|
||||||
if is_transformers_version_greater_than("4.52.0")
|
|
||||||
else ["model", "lm_head"],
|
|
||||||
lora_conflict_keys=["patch_embed"],
|
lora_conflict_keys=["patch_embed"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -346,9 +352,7 @@ _register_composite_model(
|
|||||||
model_type="qwen2_5_vl",
|
model_type="qwen2_5_vl",
|
||||||
projector_key="visual.merger",
|
projector_key="visual.merger",
|
||||||
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
||||||
language_model_keys=["language_model", "lm_head"]
|
language_model_keys=["language_model", "lm_head"],
|
||||||
if is_transformers_version_greater_than("4.52.0")
|
|
||||||
else ["model", "lm_head"],
|
|
||||||
lora_conflict_keys=["patch_embed"],
|
lora_conflict_keys=["patch_embed"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -381,7 +385,25 @@ _register_composite_model(
|
|||||||
"visual.deepstack_merger_list",
|
"visual.deepstack_merger_list",
|
||||||
"audio_tower",
|
"audio_tower",
|
||||||
],
|
],
|
||||||
language_model_keys=["model", "lm_head"],
|
language_model_keys=["language_model", "lm_head"],
|
||||||
|
lora_conflict_keys=["patch_embed"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_composite_model(
|
||||||
|
model_type="qwen3_5",
|
||||||
|
projector_key="model.visual.merger",
|
||||||
|
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
|
||||||
|
language_model_keys=["language_model", "lm_head"],
|
||||||
|
lora_conflict_keys=["patch_embed"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_composite_model(
|
||||||
|
model_type="qwen3_5_moe",
|
||||||
|
projector_key="model.visual.merger",
|
||||||
|
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
|
||||||
|
language_model_keys=["language_model", "lm_head"],
|
||||||
lora_conflict_keys=["patch_embed"],
|
lora_conflict_keys=["patch_embed"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ from .model_utils.embedding import resize_embedding_layer
|
|||||||
from .model_utils.kv_cache import configure_kv_cache
|
from .model_utils.kv_cache import configure_kv_cache
|
||||||
from .model_utils.longlora import configure_longlora
|
from .model_utils.longlora import configure_longlora
|
||||||
from .model_utils.moe import add_z3_leaf_module, configure_moe
|
from .model_utils.moe import add_z3_leaf_module, configure_moe
|
||||||
from .model_utils.packing import configure_packing
|
|
||||||
from .model_utils.quantization import configure_quantization
|
from .model_utils.quantization import configure_quantization
|
||||||
from .model_utils.rope import configure_rope
|
from .model_utils.rope import configure_rope
|
||||||
from .model_utils.valuehead import prepare_valuehead_model
|
from .model_utils.valuehead import prepare_valuehead_model
|
||||||
@@ -142,7 +141,6 @@ def patch_config(
|
|||||||
configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
|
configure_quantization(config, tokenizer, model_args, is_trainable, init_kwargs)
|
||||||
configure_moe(config, model_args, is_trainable)
|
configure_moe(config, model_args, is_trainable)
|
||||||
configure_visual_model(config)
|
configure_visual_model(config)
|
||||||
configure_packing(model_args, is_trainable)
|
|
||||||
configure_kv_cache(config, model_args, is_trainable)
|
configure_kv_cache(config, model_args, is_trainable)
|
||||||
|
|
||||||
if getattr(config, "model_type", None) == "qwen":
|
if getattr(config, "model_type", None) == "qwen":
|
||||||
|
|||||||
@@ -228,7 +228,7 @@ class LogCallback(TrainerCallback):
|
|||||||
if (
|
if (
|
||||||
args.should_save
|
args.should_save
|
||||||
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
||||||
and args.overwrite_output_dir
|
and getattr(args, "overwrite_output_dir", False)
|
||||||
):
|
):
|
||||||
logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
|
logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
|
||||||
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
||||||
@@ -371,6 +371,18 @@ class ReporterCallback(TrainerCallback):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "trackio" in args.report_to:
|
||||||
|
import trackio
|
||||||
|
|
||||||
|
trackio.config.update(
|
||||||
|
{
|
||||||
|
"model_args": self.model_args.to_dict(),
|
||||||
|
"data_args": self.data_args.to_dict(),
|
||||||
|
"finetuning_args": self.finetuning_args.to_dict(),
|
||||||
|
"generating_args": self.generating_args.to_dict(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if self.finetuning_args.use_swanlab:
|
if self.finetuning_args.use_swanlab:
|
||||||
import swanlab # type: ignore
|
import swanlab # type: ignore
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import json
|
||||||
|
import os
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
@@ -77,12 +79,43 @@ def _data_collator_wrapper(data_collator: Any):
|
|||||||
|
|
||||||
def _check_model_support(model_args: "ModelArguments"):
|
def _check_model_support(model_args: "ModelArguments"):
|
||||||
from transformers import AutoConfig as HfAutoConfig
|
from transformers import AutoConfig as HfAutoConfig
|
||||||
|
if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")): # load from mcore ckpt
|
||||||
|
mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
|
||||||
|
model_type = mca_config.get("hf_model_type", None)
|
||||||
|
else:
|
||||||
|
config = HfAutoConfig.from_pretrained(
|
||||||
|
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||||
|
)
|
||||||
|
model_type = config.model_type
|
||||||
|
|
||||||
config = HfAutoConfig.from_pretrained(
|
if model_type not in MCA_SUPPORTED_MODELS:
|
||||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
raise ValueError(
|
||||||
)
|
f"Model {model_type} is not supported by mcore_adapter."
|
||||||
if config.model_type not in MCA_SUPPORTED_MODELS:
|
"You can try to upgrade mcore_adapter to the latest version for more supported models."
|
||||||
raise ValueError(f"Model {config.model_type} is not supported by MCA.")
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
|
||||||
|
"""Freeze model parameters for qwen_vl series models based on finetuning arguments."""
|
||||||
|
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
params_to_freeze = []
|
||||||
|
if finetuning_args.freeze_vision_tower:
|
||||||
|
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
|
||||||
|
if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe"]:
|
||||||
|
params_to_freeze.extend(["vision_model.pos_embed"])
|
||||||
|
|
||||||
|
if finetuning_args.freeze_multi_modal_projector:
|
||||||
|
params_to_freeze.extend(["multi_modal_projector"])
|
||||||
|
|
||||||
|
if finetuning_args.freeze_language_model:
|
||||||
|
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
|
||||||
|
|
||||||
|
if params_to_freeze:
|
||||||
|
for name, p in model.named_parameters():
|
||||||
|
if any(name.startswith(k) for k in params_to_freeze):
|
||||||
|
p.requires_grad_(False)
|
||||||
|
|
||||||
|
|
||||||
def run_pt(
|
def run_pt(
|
||||||
@@ -161,22 +194,8 @@ def run_sft(
|
|||||||
_check_model_support(model_args)
|
_check_model_support(model_args)
|
||||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||||
|
|
||||||
# optional freezing for qwen2_vl, qwen2_5_vl
|
# optional freezing for qwen_vl series
|
||||||
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
|
_freeze_model_parameters(model, finetuning_args)
|
||||||
params_to_freeze = []
|
|
||||||
if finetuning_args.freeze_vision_tower:
|
|
||||||
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
|
|
||||||
|
|
||||||
if finetuning_args.freeze_multi_modal_projector:
|
|
||||||
params_to_freeze.extend(["multi_modal_projector"])
|
|
||||||
|
|
||||||
if finetuning_args.freeze_language_model:
|
|
||||||
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
|
|
||||||
|
|
||||||
if params_to_freeze:
|
|
||||||
for name, p in model.named_parameters():
|
|
||||||
if any(name.startswith(k) for k in params_to_freeze):
|
|
||||||
p.requires_grad_(False)
|
|
||||||
|
|
||||||
pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
|
pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
|
||||||
data_collator = SFTDataCollatorWith4DAttentionMask(
|
data_collator = SFTDataCollatorWith4DAttentionMask(
|
||||||
@@ -229,6 +248,8 @@ def run_dpo(
|
|||||||
_check_model_support(model_args)
|
_check_model_support(model_args)
|
||||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||||
|
|
||||||
|
_freeze_model_parameters(model, finetuning_args)
|
||||||
|
|
||||||
if finetuning_args.use_ref_model:
|
if finetuning_args.use_ref_model:
|
||||||
ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
|
ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
|
||||||
ref_model = AutoModel.from_config(ref_config)
|
ref_model = AutoModel.from_config(ref_config)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from functools import partial
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
@@ -52,6 +53,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
model_args: Optional["ModelArguments"] = None,
|
model_args: Optional["ModelArguments"] = None,
|
||||||
gen_kwargs: Optional[dict[str, Any]] = None,
|
gen_kwargs: Optional[dict[str, Any]] = None,
|
||||||
|
ref_model: Optional["torch.nn.Module"] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||||
@@ -82,6 +84,27 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||||
self.add_callback(BAdamCallback)
|
self.add_callback(BAdamCallback)
|
||||||
|
|
||||||
|
self.ref_model = ref_model
|
||||||
|
|
||||||
|
if ref_model is not None:
|
||||||
|
from trl.models.utils import prepare_deepspeed, prepare_fsdp
|
||||||
|
|
||||||
|
if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
|
||||||
|
if not (
|
||||||
|
getattr(ref_model, "is_loaded_in_8bit", False) or getattr(ref_model, "is_loaded_in_4bit", False)
|
||||||
|
): # quantized models are already set on the correct device
|
||||||
|
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||||
|
elif getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
|
||||||
|
if self.accelerator.is_fsdp2:
|
||||||
|
from accelerate.utils.fsdp_utils import fsdp2_prepare_model
|
||||||
|
|
||||||
|
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model)
|
||||||
|
else:
|
||||||
|
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
|
||||||
|
else:
|
||||||
|
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||||
|
self.ref_model.eval()
|
||||||
|
|
||||||
if finetuning_args.use_dft_loss:
|
if finetuning_args.use_dft_loss:
|
||||||
from ..trainer_utils import dft_loss_func
|
from ..trainer_utils import dft_loss_func
|
||||||
|
|
||||||
@@ -93,6 +116,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
|
self.compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
|
||||||
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
||||||
)
|
)
|
||||||
|
elif finetuning_args.use_asft_loss:
|
||||||
|
from ..trainer_utils import asft_loss_func
|
||||||
|
|
||||||
|
self.compute_loss_func = partial(
|
||||||
|
asft_loss_func,
|
||||||
|
asft_alpha=finetuning_args.asft_alpha,
|
||||||
|
)
|
||||||
|
|
||||||
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||||
verify_fp8_status(self.accelerator, training_args)
|
verify_fp8_status(self.accelerator, training_args)
|
||||||
@@ -119,7 +149,17 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
def compute_loss(self, model, inputs, *args, **kwargs):
|
def compute_loss(self, model, inputs, *args, **kwargs):
|
||||||
return super().compute_loss(model, inputs, *args, **kwargs)
|
if self.finetuning_args.use_asft_loss:
|
||||||
|
with torch.no_grad():
|
||||||
|
ref_outputs = self.ref_model(
|
||||||
|
input_ids=inputs["input_ids"],
|
||||||
|
attention_mask=inputs.get("attention_mask", None),
|
||||||
|
)
|
||||||
|
ref_logits = ref_outputs.logits
|
||||||
|
outputs = model(**inputs)
|
||||||
|
return self.compute_loss_func(outputs, inputs["labels"], ref_logits)
|
||||||
|
else:
|
||||||
|
return super().compute_loss(model, inputs, *args, **kwargs)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def prediction_step(
|
def prediction_step(
|
||||||
@@ -175,7 +215,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
if len(pad_len): # move pad token to last
|
if len(pad_len): # move pad token to last
|
||||||
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
|
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
|
||||||
|
|
||||||
decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
|
input_ids_column = dataset["input_ids"]
|
||||||
|
try:
|
||||||
|
input_ids_list = input_ids_column.to_pylist()
|
||||||
|
except AttributeError:
|
||||||
|
input_ids_list = list(input_ids_column)
|
||||||
|
|
||||||
|
decoded_inputs = self.processing_class.batch_decode(input_ids_list, skip_special_tokens=False)
|
||||||
decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
|
decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
|
||||||
decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
|
decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from ...extras.misc import calculate_tps
|
|||||||
from ...extras.packages import is_transformers_version_greater_than
|
from ...extras.packages import is_transformers_version_greater_than
|
||||||
from ...extras.ploting import plot_loss
|
from ...extras.ploting import plot_loss
|
||||||
from ...model import load_model, load_tokenizer
|
from ...model import load_model, load_tokenizer
|
||||||
from ..trainer_utils import create_modelcard_and_push
|
from ..trainer_utils import create_modelcard_and_push, create_ref_model
|
||||||
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
|
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
|
||||||
from .trainer import CustomSeq2SeqTrainer
|
from .trainer import CustomSeq2SeqTrainer
|
||||||
|
|
||||||
@@ -52,6 +52,10 @@ def run_sft(
|
|||||||
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||||
|
|
||||||
|
ref_model = None
|
||||||
|
if finetuning_args.use_asft_loss:
|
||||||
|
ref_model = create_ref_model(model_args, finetuning_args)
|
||||||
|
|
||||||
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
||||||
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
||||||
|
|
||||||
@@ -61,6 +65,7 @@ def run_sft(
|
|||||||
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
|
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
|
||||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
|
||||||
block_diag_attn=model_args.block_diag_attn,
|
block_diag_attn=model_args.block_diag_attn,
|
||||||
|
neat_packing=data_args.neat_packing,
|
||||||
attn_implementation=getattr(model.config, "_attn_implementation", None),
|
attn_implementation=getattr(model.config, "_attn_implementation", None),
|
||||||
compute_dtype=model_args.compute_dtype,
|
compute_dtype=model_args.compute_dtype,
|
||||||
**tokenizer_module,
|
**tokenizer_module,
|
||||||
@@ -124,6 +129,7 @@ def run_sft(
|
|||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
gen_kwargs=gen_kwargs,
|
gen_kwargs=gen_kwargs,
|
||||||
|
ref_model=ref_model,
|
||||||
**dataset_module,
|
**dataset_module,
|
||||||
**tokenizer_module,
|
**tokenizer_module,
|
||||||
**metric_module,
|
**metric_module,
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from collections.abc import Callable, Mapping
|
|||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
from transformers.modeling_utils import is_fsdp_enabled
|
from transformers.modeling_utils import is_fsdp_enabled
|
||||||
@@ -51,6 +52,7 @@ if is_ray_available():
|
|||||||
import ray
|
import ray
|
||||||
from ray.util.placement_group import PlacementGroup, placement_group
|
from ray.util.placement_group import PlacementGroup, placement_group
|
||||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||||
|
from ray.util.state import list_nodes
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -681,6 +683,88 @@ def _dft_cross_entropy(
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def asft_loss_func(
|
||||||
|
outputs,
|
||||||
|
labels: torch.Tensor,
|
||||||
|
ref_logits: torch.Tensor,
|
||||||
|
asft_alpha: float = 0.1,
|
||||||
|
ignore_index: int = -100,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
logits = outputs.get("logits")
|
||||||
|
if logits is None:
|
||||||
|
return outputs.get("loss", torch.tensor(0.0))
|
||||||
|
|
||||||
|
logits = logits.float()
|
||||||
|
|
||||||
|
# shift for causal LM
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
shift_ref_logits = ref_logits[..., :-1, :].contiguous()
|
||||||
|
|
||||||
|
vocab_size = shift_logits.size(-1)
|
||||||
|
|
||||||
|
# flatten
|
||||||
|
shift_logits = shift_logits.view(-1, vocab_size)
|
||||||
|
shift_ref_logits = shift_ref_logits.view(-1, vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1).to(shift_logits.device)
|
||||||
|
|
||||||
|
return _asft_cross_entropy(
|
||||||
|
policy_logits=shift_logits,
|
||||||
|
policy_labels=shift_labels,
|
||||||
|
ref_logits=shift_ref_logits,
|
||||||
|
asft_alpha=asft_alpha,
|
||||||
|
ignore_index=ignore_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _asft_cross_entropy(
|
||||||
|
policy_logits: torch.Tensor,
|
||||||
|
policy_labels: torch.Tensor,
|
||||||
|
ref_logits: torch.Tensor,
|
||||||
|
asft_alpha: float = 0.1,
|
||||||
|
ignore_index: int = -100,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
dft_loss = _dft_cross_entropy(
|
||||||
|
policy_logits,
|
||||||
|
policy_labels,
|
||||||
|
ignore_index=ignore_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
kl_loss = _kl_divergence(
|
||||||
|
policy_logits,
|
||||||
|
ref_logits,
|
||||||
|
policy_labels,
|
||||||
|
ignore_index=ignore_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
return dft_loss + asft_alpha * kl_loss
|
||||||
|
|
||||||
|
|
||||||
|
def _kl_divergence(
|
||||||
|
policy_logits: torch.Tensor,
|
||||||
|
ref_logits: torch.Tensor,
|
||||||
|
labels: torch.Tensor,
|
||||||
|
ignore_index: int = -100,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# log p(y|x)
|
||||||
|
log_p = F.log_softmax(policy_logits, dim=-1)
|
||||||
|
|
||||||
|
# q(y|x)
|
||||||
|
q = F.softmax(ref_logits, dim=-1)
|
||||||
|
|
||||||
|
# token-wise KL
|
||||||
|
kl = F.kl_div(
|
||||||
|
log_p,
|
||||||
|
q,
|
||||||
|
reduction="none",
|
||||||
|
).sum(dim=-1) # [N]
|
||||||
|
|
||||||
|
# mask padding tokens
|
||||||
|
mask = (labels != ignore_index).float()
|
||||||
|
|
||||||
|
return (kl * mask).sum() / mask.sum()
|
||||||
|
|
||||||
|
|
||||||
def eaft_loss_func(
|
def eaft_loss_func(
|
||||||
outputs: "torch.Tensor",
|
outputs: "torch.Tensor",
|
||||||
labels: "torch.Tensor",
|
labels: "torch.Tensor",
|
||||||
@@ -858,7 +942,7 @@ def get_ray_remote_config_for_worker(
|
|||||||
|
|
||||||
def get_ray_head_node_ip() -> str:
|
def get_ray_head_node_ip() -> str:
|
||||||
r"""Get the IP address of the Ray head node."""
|
r"""Get the IP address of the Ray head node."""
|
||||||
head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
|
head_ip = next(node["node_ip"] for node in list_nodes() if node.get("is_head_node", False))
|
||||||
return head_ip
|
return head_ip
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
|
|||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||||
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
|
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
|
||||||
from ..extras.packages import is_mcore_adapter_available, is_ray_available
|
from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
|
||||||
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
||||||
from ..model import load_model, load_tokenizer
|
from ..model import load_model, load_tokenizer
|
||||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||||
@@ -160,17 +160,28 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
|
|||||||
model = model.to(output_dtype)
|
model = model.to(output_dtype)
|
||||||
logger.info_rank0(f"Convert model dtype to: {output_dtype}.")
|
logger.info_rank0(f"Convert model dtype to: {output_dtype}.")
|
||||||
|
|
||||||
model.save_pretrained(
|
# Prepare save arguments (safe_serialization removed in transformers v5.0.0)
|
||||||
save_directory=model_args.export_dir,
|
save_kwargs = {
|
||||||
max_shard_size=f"{model_args.export_size}GB",
|
"save_directory": model_args.export_dir,
|
||||||
safe_serialization=(not model_args.export_legacy_format),
|
"max_shard_size": f"{model_args.export_size}GB",
|
||||||
)
|
}
|
||||||
|
if not is_transformers_version_greater_than("5.0.0"):
|
||||||
|
save_kwargs["safe_serialization"] = not model_args.export_legacy_format
|
||||||
|
|
||||||
|
model.save_pretrained(**save_kwargs)
|
||||||
|
|
||||||
if model_args.export_hub_model_id is not None:
|
if model_args.export_hub_model_id is not None:
|
||||||
|
# Prepare push arguments (safe_serialization removed in transformers v5.0.0)
|
||||||
|
push_kwargs = {
|
||||||
|
"max_shard_size": f"{model_args.export_size}GB",
|
||||||
|
}
|
||||||
|
if not is_transformers_version_greater_than("5.0.0"):
|
||||||
|
push_kwargs["safe_serialization"] = not model_args.export_legacy_format
|
||||||
|
|
||||||
model.push_to_hub(
|
model.push_to_hub(
|
||||||
model_args.export_hub_model_id,
|
model_args.export_hub_model_id,
|
||||||
token=model_args.hf_hub_token,
|
token=model_args.hf_hub_token,
|
||||||
max_shard_size=f"{model_args.export_size}GB",
|
**push_kwargs,
|
||||||
safe_serialization=(not model_args.export_legacy_format),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if finetuning_args.stage == "rm":
|
if finetuning_args.stage == "rm":
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from omegaconf import OmegaConf
|
|||||||
from transformers import HfArgumentParser
|
from transformers import HfArgumentParser
|
||||||
|
|
||||||
from ..utils.env import is_env_enabled
|
from ..utils.env import is_env_enabled
|
||||||
|
from ..utils.helper import set_seed
|
||||||
from .data_args import DataArguments
|
from .data_args import DataArguments
|
||||||
from .model_args import ModelArguments
|
from .model_args import ModelArguments
|
||||||
from .sample_args import SampleArguments
|
from .sample_args import SampleArguments
|
||||||
@@ -56,6 +57,14 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
|
|||||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||||
|
|
||||||
|
# Seed as early as possible after argument parsing so all downstream
|
||||||
|
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
|
||||||
|
for arg in parsed_args:
|
||||||
|
seed = getattr(arg, "seed", None)
|
||||||
|
if seed is not None:
|
||||||
|
set_seed(seed)
|
||||||
|
break
|
||||||
|
|
||||||
return tuple(parsed_args)
|
return tuple(parsed_args)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class TrainingArguments:
|
|||||||
metadata={"help": "Number of workers for batching."},
|
metadata={"help": "Number of workers for batching."},
|
||||||
)
|
)
|
||||||
enable_activation_checkpointing: bool = field(
|
enable_activation_checkpointing: bool = field(
|
||||||
default=True,
|
default=False,
|
||||||
metadata={"help": "Enable activation checkpointing for training."},
|
metadata={"help": "Enable activation checkpointing for training."},
|
||||||
)
|
)
|
||||||
dist_config: PluginConfig | None = field(
|
dist_config: PluginConfig | None = field(
|
||||||
@@ -81,6 +81,10 @@ class TrainingArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Learning rate scheduler configuration for training."},
|
metadata={"help": "Learning rate scheduler configuration for training."},
|
||||||
)
|
)
|
||||||
|
seed: int = field(
|
||||||
|
default=42,
|
||||||
|
metadata={"help": "Random seed that will be set at the beginning of training."},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
self.dist_config = get_plugin_config(self.dist_config)
|
self.dist_config = get_plugin_config(self.dist_config)
|
||||||
|
|||||||
@@ -76,19 +76,28 @@ class BaseTrainer:
|
|||||||
if self.args.enable_activation_checkpointing:
|
if self.args.enable_activation_checkpointing:
|
||||||
self.model.gradient_checkpointing_enable({"use_reentrant": False})
|
self.model.gradient_checkpointing_enable({"use_reentrant": False})
|
||||||
|
|
||||||
if self.args.dist_config is not None:
|
self._deepspeed_engine = None
|
||||||
shard_need_optimizer = self.args.dist_config.name == "deepspeed"
|
dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
|
||||||
else:
|
|
||||||
shard_need_optimizer = False
|
|
||||||
|
|
||||||
if shard_need_optimizer:
|
if dist_name == "deepspeed":
|
||||||
|
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
|
||||||
|
|
||||||
|
self._deepspeed_engine = DistributedPlugin("deepspeed")(
|
||||||
|
self.model,
|
||||||
|
self.args.dist_config,
|
||||||
|
num_micro_batch=self.train_batch_generator.num_micro_batch,
|
||||||
|
micro_batch_size=self.args.micro_batch_size,
|
||||||
|
)
|
||||||
self._init_optimizer()
|
self._init_optimizer()
|
||||||
self._shard_model()
|
self._init_lr_scheduler()
|
||||||
|
self.model, self.optimizer, self.lr_scheduler = self._deepspeed_engine.prepare(
|
||||||
|
self.model, self.optimizer, self.lr_scheduler
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
|
# fsdp2 / DDP / no dist
|
||||||
self._shard_model()
|
self._shard_model()
|
||||||
self._init_optimizer()
|
self._init_optimizer()
|
||||||
|
self._init_lr_scheduler()
|
||||||
self._init_lr_scheduler()
|
|
||||||
|
|
||||||
def _create_batch_generator(self) -> None:
|
def _create_batch_generator(self) -> None:
|
||||||
self.train_batch_generator = BatchGenerator(
|
self.train_batch_generator = BatchGenerator(
|
||||||
@@ -99,6 +108,7 @@ class BaseTrainer:
|
|||||||
cutoff_len=self.args.cutoff_len,
|
cutoff_len=self.args.cutoff_len,
|
||||||
batching_workers=self.args.batching_workers,
|
batching_workers=self.args.batching_workers,
|
||||||
batching_strategy=self.args.batching_strategy,
|
batching_strategy=self.args.batching_strategy,
|
||||||
|
seed=self.args.seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _shard_model(self) -> None:
|
def _shard_model(self) -> None:
|
||||||
@@ -171,25 +181,35 @@ class BaseTrainer:
|
|||||||
step_loss = 0
|
step_loss = 0
|
||||||
step_valid_tokens = compute_valid_tokens(micro_batches)
|
step_valid_tokens = compute_valid_tokens(micro_batches)
|
||||||
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
|
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
|
||||||
for micro_batch in micro_batches:
|
num_micro = len(micro_batches)
|
||||||
|
for i, micro_batch in enumerate(micro_batches):
|
||||||
loss = self.compute_loss(micro_batch)
|
loss = self.compute_loss(micro_batch)
|
||||||
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
|
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
|
||||||
# fsdp uses mean reduction so we need to scale the loss by dp_size
|
# fsdp uses mean reduction so we need to scale the loss by dp_size
|
||||||
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
|
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
|
||||||
|
|
||||||
loss.backward()
|
if self._deepspeed_engine is not None:
|
||||||
|
# deepspeed: set sync_gradients so engine.step() only fires on last micro-batch
|
||||||
|
self._deepspeed_engine.accelerator.sync_gradients = i == num_micro - 1
|
||||||
|
self._deepspeed_engine.backward(loss)
|
||||||
|
else:
|
||||||
|
loss.backward()
|
||||||
step_loss += loss.item()
|
step_loss += loss.item()
|
||||||
|
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
|
if self._deepspeed_engine is not None:
|
||||||
|
# deepspeed: engine.step() already ran inside backward at the sync boundary
|
||||||
# isfinite(): argument 'input' (position 1) must be Tensor, not float
|
grad_norm = self._deepspeed_engine.get_grad_norm()
|
||||||
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
|
|
||||||
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
|
|
||||||
else:
|
else:
|
||||||
self.optimizer.step()
|
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
|
||||||
|
|
||||||
self.lr_scheduler.step()
|
# isfinite(): argument 'input' (position 1) must be Tensor, not float
|
||||||
self.optimizer.zero_grad()
|
if not torch.isfinite(torch.tensor(grad_norm)): # type: ignore # pyright: ignore [reportUnknownReturnType]
|
||||||
|
logger.warning_rank0(f"Gradient norm is not finite: {grad_norm}")
|
||||||
|
else:
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
self.lr_scheduler.step()
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
|
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
|
||||||
DistributedInterface().sync()
|
DistributedInterface().sync()
|
||||||
@@ -203,7 +223,14 @@ class BaseTrainer:
|
|||||||
|
|
||||||
def save_model(self) -> None:
|
def save_model(self) -> None:
|
||||||
"""Save the model."""
|
"""Save the model."""
|
||||||
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
|
||||||
model_to_save.save_pretrained(self.args.output_dir)
|
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
|
||||||
self.renderer.processor.save_pretrained(self.args.output_dir)
|
|
||||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
DistributedPlugin(self.args.dist_config.name).save_model(
|
||||||
|
self.model, self.args.output_dir, self.renderer.processor
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
||||||
|
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||||
|
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||||
|
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||||
|
|||||||
@@ -90,6 +90,26 @@ class ModelEngine:
|
|||||||
Transformers can choose the proper model init context.
|
Transformers can choose the proper model init context.
|
||||||
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
|
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
|
||||||
"""
|
"""
|
||||||
|
if self.args.init_config is not None:
|
||||||
|
from ..plugins.model_plugins.initialization import InitPlugin
|
||||||
|
|
||||||
|
init_device = InitPlugin(self.args.init_config.name)()
|
||||||
|
else:
|
||||||
|
init_device = DistributedInterface().current_device
|
||||||
|
|
||||||
|
init_kwargs = {"device_map": init_device}
|
||||||
|
|
||||||
|
if self.args.quant_config is not None:
|
||||||
|
from ..plugins.model_plugins.quantization import QuantizationPlugin
|
||||||
|
|
||||||
|
init_kwargs = QuantizationPlugin(self.args.quant_config.name)(
|
||||||
|
init_kwargs=init_kwargs,
|
||||||
|
config=self.model_config,
|
||||||
|
tokenizer=self.processor,
|
||||||
|
model_args=self.args,
|
||||||
|
is_trainable=self.is_train,
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.model_class == ModelClass.LLM:
|
if self.args.model_class == ModelClass.LLM:
|
||||||
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
|
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
|
||||||
|
|
||||||
@@ -107,14 +127,8 @@ class ModelEngine:
|
|||||||
|
|
||||||
AutoClass = AutoModel
|
AutoClass = AutoModel
|
||||||
|
|
||||||
if self.args.init_config is not None:
|
|
||||||
from ..plugins.model_plugins.initialization import InitPlugin
|
|
||||||
|
|
||||||
init_device = InitPlugin(self.args.init_config.name)()
|
|
||||||
else:
|
|
||||||
init_device = DistributedInterface().current_device
|
|
||||||
|
|
||||||
if init_device.type == DeviceType.META:
|
if init_device.type == DeviceType.META:
|
||||||
|
assert self.args.quant_config is None, "Quantization is not supported with meta device."
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = AutoClass.from_config(self.model_config)
|
model = AutoClass.from_config(self.model_config)
|
||||||
else:
|
else:
|
||||||
@@ -122,8 +136,8 @@ class ModelEngine:
|
|||||||
self.args.model,
|
self.args.model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
dtype="auto",
|
dtype="auto",
|
||||||
device_map=init_device,
|
|
||||||
trust_remote_code=self.args.trust_remote_code,
|
trust_remote_code=self.args.trust_remote_code,
|
||||||
|
**init_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.peft_config is None:
|
if self.args.peft_config is None:
|
||||||
|
|||||||
@@ -26,6 +26,7 @@
|
|||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch.utils.data import default_collate
|
from torch.utils.data import default_collate
|
||||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||||
@@ -71,6 +72,7 @@ class BatchGenerator(Iterator):
|
|||||||
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
|
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
|
||||||
pin_memory: bool = True,
|
pin_memory: bool = True,
|
||||||
drop_last: bool = True,
|
drop_last: bool = True,
|
||||||
|
seed: int = 42,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.renderer = renderer
|
self.renderer = renderer
|
||||||
@@ -82,6 +84,7 @@ class BatchGenerator(Iterator):
|
|||||||
self.batching_strategy = batching_strategy
|
self.batching_strategy = batching_strategy
|
||||||
self.pin_memory = pin_memory
|
self.pin_memory = pin_memory
|
||||||
self.drop_last = drop_last
|
self.drop_last = drop_last
|
||||||
|
self.seed = seed
|
||||||
# TODO: support length and infinity
|
# TODO: support length and infinity
|
||||||
dp_size = DistributedInterface().get_world_size(Dim.DP)
|
dp_size = DistributedInterface().get_world_size(Dim.DP)
|
||||||
|
|
||||||
@@ -128,12 +131,15 @@ class BatchGenerator(Iterator):
|
|||||||
num_replicas=DistributedInterface().get_world_size(Dim.DP),
|
num_replicas=DistributedInterface().get_world_size(Dim.DP),
|
||||||
rank=DistributedInterface().get_rank(Dim.DP),
|
rank=DistributedInterface().get_rank(Dim.DP),
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
seed=0,
|
seed=self.seed,
|
||||||
drop_last=self.drop_last,
|
drop_last=self.drop_last,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Iterable dataset is not supported yet.")
|
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||||
|
|
||||||
|
generato_seed = torch.Generator()
|
||||||
|
generato_seed.manual_seed(self.seed)
|
||||||
|
|
||||||
self._data_provider = StatefulDataLoader(
|
self._data_provider = StatefulDataLoader(
|
||||||
self.dataset,
|
self.dataset,
|
||||||
batch_size=self.micro_batch_size * self.num_micro_batch,
|
batch_size=self.micro_batch_size * self.num_micro_batch,
|
||||||
@@ -143,6 +149,7 @@ class BatchGenerator(Iterator):
|
|||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
pin_memory_device=DistributedInterface().current_device.type,
|
pin_memory_device=DistributedInterface().current_device.type,
|
||||||
drop_last=self.drop_last,
|
drop_last=self.drop_last,
|
||||||
|
generator=generato_seed,
|
||||||
)
|
)
|
||||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||||
self._length = len(self._data_provider)
|
self._length = len(self._data_provider)
|
||||||
|
|||||||
@@ -91,7 +91,11 @@ class Renderer:
|
|||||||
self.processor = processor
|
self.processor = processor
|
||||||
|
|
||||||
def render_messages(
|
def render_messages(
|
||||||
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
tools: str | None = None,
|
||||||
|
is_generate: bool = False,
|
||||||
|
enable_thinking: bool = False,
|
||||||
) -> ModelInput:
|
) -> ModelInput:
|
||||||
"""Apply template to messages and convert them to model input.
|
"""Apply template to messages and convert them to model input.
|
||||||
|
|
||||||
@@ -99,6 +103,7 @@ class Renderer:
|
|||||||
messages (list[Message]): The messages to render.
|
messages (list[Message]): The messages to render.
|
||||||
tools (str | None, optional): The tools to use. Defaults to None.
|
tools (str | None, optional): The tools to use. Defaults to None.
|
||||||
is_generate (bool, optional): Whether to render for generation. Defaults to False.
|
is_generate (bool, optional): Whether to render for generation. Defaults to False.
|
||||||
|
enable_thinking (bool, optional): Whether to enable thinking mode for generation. Defaults to False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ModelInput: The rendered model input.
|
ModelInput: The rendered model input.
|
||||||
@@ -108,7 +113,9 @@ class Renderer:
|
|||||||
else:
|
else:
|
||||||
from ...plugins.model_plugins.rendering import RenderingPlugin
|
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||||
|
|
||||||
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
|
return RenderingPlugin(self.template).render_messages(
|
||||||
|
self.processor, messages, tools, is_generate, enable_thinking
|
||||||
|
)
|
||||||
|
|
||||||
def parse_message(self, generated_text: str) -> Message:
|
def parse_message(self, generated_text: str) -> Message:
|
||||||
"""Parse a message in the template format.
|
"""Parse a message in the template format.
|
||||||
|
|||||||
@@ -125,6 +125,11 @@ def launch():
|
|||||||
|
|
||||||
run_chat()
|
run_chat()
|
||||||
|
|
||||||
|
elif command == "merge":
|
||||||
|
from llamafactory.v1.plugins.model_plugins.peft import merge_and_export_model
|
||||||
|
|
||||||
|
merge_and_export_model()
|
||||||
|
|
||||||
elif command == "env":
|
elif command == "env":
|
||||||
raise NotImplementedError("Environment information is not implemented yet.")
|
raise NotImplementedError("Environment information is not implemented yet.")
|
||||||
|
|
||||||
|
|||||||
@@ -12,14 +12,22 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Literal, TypedDict
|
import re
|
||||||
|
from typing import Literal, TypedDict, Union
|
||||||
|
|
||||||
from peft import LoraConfig, PeftModel, get_peft_model
|
import torch
|
||||||
|
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
|
||||||
|
|
||||||
|
from ...config import InputArgument, get_args
|
||||||
|
from ...core.model_engine import ModelEngine
|
||||||
|
from ...utils import logging
|
||||||
from ...utils.plugin import BasePlugin
|
from ...utils.plugin import BasePlugin
|
||||||
from ...utils.types import HFModel
|
from ...utils.types import HFModel
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LoraConfigDict(TypedDict, total=False):
|
class LoraConfigDict(TypedDict, total=False):
|
||||||
name: Literal["lora"]
|
name: Literal["lora"]
|
||||||
"""Plugin name."""
|
"""Plugin name."""
|
||||||
@@ -27,8 +35,28 @@ class LoraConfigDict(TypedDict, total=False):
|
|||||||
"""Lora rank."""
|
"""Lora rank."""
|
||||||
lora_alpha: int
|
lora_alpha: int
|
||||||
"""Lora alpha."""
|
"""Lora alpha."""
|
||||||
target_modules: list[str]
|
lora_dropout: float
|
||||||
|
"""Lora dropout."""
|
||||||
|
target_modules: Union[list[str], str]
|
||||||
"""Target modules."""
|
"""Target modules."""
|
||||||
|
use_rslora: bool
|
||||||
|
"""Use RS-LoRA."""
|
||||||
|
use_dora: bool
|
||||||
|
"""Use DoRA."""
|
||||||
|
modules_to_save: list[str]
|
||||||
|
"""Modules to save."""
|
||||||
|
adapter_name_or_path: Union[list[str], str]
|
||||||
|
"""Path to the adapter(s)."""
|
||||||
|
export_dir: str
|
||||||
|
"""Path to the export directory."""
|
||||||
|
export_size: int
|
||||||
|
"""Shard size for the export model."""
|
||||||
|
export_hub_model_id: str
|
||||||
|
"""Hub model ID for the export model."""
|
||||||
|
infer_dtype: Literal["auto", "float16", "float32", "bfloat16"]
|
||||||
|
"""Inference data type for the export model."""
|
||||||
|
export_legacy_format: bool
|
||||||
|
"""Use legacy format for the export model."""
|
||||||
|
|
||||||
|
|
||||||
class FreezeConfigDict(TypedDict, total=False):
|
class FreezeConfigDict(TypedDict, total=False):
|
||||||
@@ -36,22 +64,283 @@ class FreezeConfigDict(TypedDict, total=False):
|
|||||||
"""Plugin name."""
|
"""Plugin name."""
|
||||||
freeze_trainable_layers: int
|
freeze_trainable_layers: int
|
||||||
"""Freeze trainable layers."""
|
"""Freeze trainable layers."""
|
||||||
freeze_trainable_modules: list[str] | None
|
freeze_trainable_modules: Union[list[str], str]
|
||||||
"""Freeze trainable modules."""
|
"""Freeze trainable modules."""
|
||||||
|
freeze_extra_modules: list[str]
|
||||||
|
"""Freeze extra modules."""
|
||||||
|
cast_trainable_params_to_fp32: bool
|
||||||
|
"""Cast trainable params to fp32."""
|
||||||
|
|
||||||
|
|
||||||
class PeftPlugin(BasePlugin):
|
class PeftPlugin(BasePlugin):
|
||||||
def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
|
def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
|
||||||
return super().__call__(model, config)
|
return super().__call__(model, config, is_train)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_all_linear_modules(model: HFModel) -> list[str]:
|
||||||
|
r"""Find all available modules to apply LoRA."""
|
||||||
|
forbidden_modules = {"lm_head", "output_layer", "output"}
|
||||||
|
module_names = set()
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if any(forbidden_module in name for forbidden_module in forbidden_modules):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
|
||||||
|
module_names.add(name.split(".")[-1])
|
||||||
|
|
||||||
|
return list(module_names)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_adapters(model: HFModel, adapter_name_or_path: Union[list[str], str]) -> HFModel:
|
||||||
|
if not isinstance(adapter_name_or_path, list):
|
||||||
|
adapter_name_or_path = [adapter_name_or_path]
|
||||||
|
|
||||||
|
for adapter_path in adapter_name_or_path:
|
||||||
|
model = PeftModel.from_pretrained(model, adapter_path)
|
||||||
|
model = model.merge_and_unload()
|
||||||
|
logger.info_rank0(f"Merged adapter from {adapter_path}")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is_train: bool) -> HFModel:
|
||||||
|
r"""Loads adapter(s) into the model.
|
||||||
|
|
||||||
|
Determine adapter usage based on mode:
|
||||||
|
- Training: Load the single adapter for continued training.
|
||||||
|
- Inference: Merge all adapters to clean up the model.
|
||||||
|
- Unmergeable: Keep the single adapter active without merging.
|
||||||
|
"""
|
||||||
|
if not isinstance(adapter_name_or_path, list):
|
||||||
|
adapter_name_or_path = [adapter_name_or_path]
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
# Adapters fix for deepspeed and quant
|
||||||
|
# Adapters fix for vision
|
||||||
|
|
||||||
|
if is_train and len(adapter_name_or_path) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"When `adapter_name_or_path` is provided for training, only a single LoRA adapter is supported. "
|
||||||
|
"Training will continue on the specified adapter. "
|
||||||
|
"Please merge multiple adapters before starting a new LoRA adapter."
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_train:
|
||||||
|
adapter_to_merge = []
|
||||||
|
adapter_to_resume = adapter_name_or_path[0]
|
||||||
|
else:
|
||||||
|
adapter_to_merge = adapter_name_or_path
|
||||||
|
adapter_to_resume = None
|
||||||
|
|
||||||
|
if adapter_to_merge:
|
||||||
|
model = merge_adapters(model, adapter_to_merge)
|
||||||
|
|
||||||
|
if adapter_to_resume is not None:
|
||||||
|
model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_train)
|
||||||
|
if is_train:
|
||||||
|
logger.info_rank0(
|
||||||
|
f"Resuming training from existing LoRA adapter at {adapter_to_resume}. "
|
||||||
|
"LoRA hyperparameters will be loaded from the adapter itself; "
|
||||||
|
"the current LoRA configuration will be ignored. "
|
||||||
|
"Merge the adapter into the base model before training if you want to start a new adapter."
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@PeftPlugin("lora").register()
|
@PeftPlugin("lora").register()
|
||||||
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
|
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
|
||||||
peft_config = LoraConfig(**config)
|
if model.device.type == "meta":
|
||||||
|
raise ValueError("Currently lora stage does not support loading model by meta.")
|
||||||
|
|
||||||
|
adapter_name_or_path = config.get("adapter_name_or_path")
|
||||||
|
|
||||||
|
if adapter_name_or_path:
|
||||||
|
return load_adapter(model, adapter_name_or_path, is_train)
|
||||||
|
|
||||||
|
logger.info_rank0("Fine-tuning method: LoRA")
|
||||||
|
|
||||||
|
target_modules = config.get("target_modules", "all")
|
||||||
|
|
||||||
|
# Handle target modules
|
||||||
|
if target_modules == "all":
|
||||||
|
target_modules = _find_all_linear_modules(model)
|
||||||
|
elif isinstance(target_modules, str):
|
||||||
|
target_modules = [target_modules]
|
||||||
|
|
||||||
|
logger.info_rank0(f"LoRA target modules: {target_modules}")
|
||||||
|
|
||||||
|
peft_config = LoraConfig(
|
||||||
|
task_type=TaskType.CAUSAL_LM,
|
||||||
|
inference_mode=not is_train,
|
||||||
|
r=config.get("r", 8),
|
||||||
|
lora_alpha=config.get("lora_alpha", 16),
|
||||||
|
lora_dropout=config.get("lora_dropout", 0.05),
|
||||||
|
use_rslora=config.get("use_rslora", False),
|
||||||
|
use_dora=config.get("use_dora", False),
|
||||||
|
target_modules=target_modules,
|
||||||
|
modules_to_save=config.get("modules_to_save", None),
|
||||||
|
)
|
||||||
|
|
||||||
model = get_peft_model(model, peft_config)
|
model = get_peft_model(model, peft_config)
|
||||||
|
|
||||||
|
if is_train:
|
||||||
|
model.print_trainable_parameters()
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@PeftPlugin("freeze").register()
|
@PeftPlugin("freeze").register()
|
||||||
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
|
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool = False) -> HFModel:
|
||||||
raise NotImplementedError()
|
logger.info_rank0("Fine-tuning method: Freeze")
|
||||||
|
|
||||||
|
if not is_train:
|
||||||
|
return model
|
||||||
|
|
||||||
|
freeze_trainable_layers = config.get("freeze_trainable_layers", 2)
|
||||||
|
freeze_trainable_modules = config.get("freeze_trainable_modules", ["all"])
|
||||||
|
freeze_extra_modules = config.get("freeze_extra_modules", [])
|
||||||
|
cast_trainable_params_to_fp32 = config.get("cast_trainable_params_to_fp32", True)
|
||||||
|
|
||||||
|
if isinstance(freeze_trainable_modules, str):
|
||||||
|
freeze_trainable_modules = [module.strip() for module in freeze_trainable_modules.split(",")]
|
||||||
|
|
||||||
|
if isinstance(freeze_extra_modules, str):
|
||||||
|
freeze_extra_modules = [module.strip() for module in freeze_extra_modules.split(",")]
|
||||||
|
|
||||||
|
# Get number of layers
|
||||||
|
num_layers = (
|
||||||
|
getattr(model.config, "num_hidden_layers", None)
|
||||||
|
or getattr(model.config, "num_layers", None)
|
||||||
|
or getattr(model.config, "n_layer", None)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not num_layers:
|
||||||
|
raise ValueError("Current model does not support freeze tuning.")
|
||||||
|
|
||||||
|
if freeze_trainable_layers > 0:
|
||||||
|
# last n layers
|
||||||
|
trainable_layer_ids = range(max(0, num_layers - freeze_trainable_layers), num_layers)
|
||||||
|
else:
|
||||||
|
# first n layers
|
||||||
|
trainable_layer_ids = range(min(-freeze_trainable_layers, num_layers))
|
||||||
|
|
||||||
|
# Identify hidden and non-hidden modules
|
||||||
|
hidden_modules = set()
|
||||||
|
non_hidden_modules = set()
|
||||||
|
for name, _ in model.named_parameters():
|
||||||
|
if ".0." in name:
|
||||||
|
hidden_modules.add(name.split(".0.")[-1].split(".")[0])
|
||||||
|
elif ".1." in name:
|
||||||
|
hidden_modules.add(name.split(".1.")[-1].split(".")[0])
|
||||||
|
|
||||||
|
if re.search(r"\.\d+\.", name) is None:
|
||||||
|
non_hidden_modules.add(name.split(".")[-2])
|
||||||
|
|
||||||
|
# Build list of trainable layer patterns
|
||||||
|
trainable_layers = []
|
||||||
|
for module_name in freeze_trainable_modules:
|
||||||
|
if module_name == "all":
|
||||||
|
for idx in trainable_layer_ids:
|
||||||
|
trainable_layers.append(f".{idx:d}.")
|
||||||
|
elif module_name in hidden_modules:
|
||||||
|
for idx in trainable_layer_ids:
|
||||||
|
trainable_layers.append(f".{idx:d}.{module_name}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Module {module_name} not found in hidden modules: {hidden_modules}")
|
||||||
|
|
||||||
|
# Add extra modules
|
||||||
|
if freeze_extra_modules:
|
||||||
|
for module_name in freeze_extra_modules:
|
||||||
|
if module_name in non_hidden_modules:
|
||||||
|
trainable_layers.append(module_name)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Module {module_name} not found in non-hidden modules: {non_hidden_modules}")
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
# Multi-modal special handling
|
||||||
|
|
||||||
|
# Set requires_grad
|
||||||
|
forbidden_modules = {"quant_state", "quantization_weight", "qweight", "qzeros", "scales"}
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
|
||||||
|
forbidden_module in name for forbidden_module in forbidden_modules
|
||||||
|
):
|
||||||
|
param.requires_grad_(True)
|
||||||
|
if cast_trainable_params_to_fp32:
|
||||||
|
param.data = param.data.to(torch.float32) # Cast to fp32 for stability
|
||||||
|
else:
|
||||||
|
param.requires_grad_(False)
|
||||||
|
|
||||||
|
logger.info_rank0(f"Set trainable layers: {trainable_layers}")
|
||||||
|
|
||||||
|
# Count trainable params for verification
|
||||||
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
all_params = sum(p.numel() for p in model.parameters())
|
||||||
|
logger.info_rank0(
|
||||||
|
f"trainable params: {trainable_params} || all params: {all_params} || trainable%: {100 * trainable_params / all_params:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def merge_and_export_model(args: InputArgument = None):
|
||||||
|
model_args, _, _, _ = get_args(args)
|
||||||
|
|
||||||
|
export_config = model_args.peft_config
|
||||||
|
if export_config is None:
|
||||||
|
raise ValueError("Please specify peft_config to merge and export model.")
|
||||||
|
|
||||||
|
export_dir = export_config.get("export_dir")
|
||||||
|
if export_dir is None:
|
||||||
|
raise ValueError("Please specify export_dir.")
|
||||||
|
|
||||||
|
export_size = export_config.get("export_size", 5)
|
||||||
|
export_hub_model_id = export_config.get("export_hub_model_id")
|
||||||
|
infer_dtype = export_config.get("infer_dtype", "auto")
|
||||||
|
export_legacy_format = export_config.get("export_legacy_format", False)
|
||||||
|
|
||||||
|
adapters = None
|
||||||
|
if export_config.get("name") == "lora":
|
||||||
|
adapters = export_config.get("adapter_name_or_path")
|
||||||
|
else:
|
||||||
|
raise ValueError("Currently merge and export model function is only supported for lora.")
|
||||||
|
|
||||||
|
if adapters is None:
|
||||||
|
raise ValueError("Please set adapter_name_or_path to merge adapters into base model.")
|
||||||
|
|
||||||
|
logger.info_rank0("Loading model for export...")
|
||||||
|
model_engine = ModelEngine(model_args, is_train=False)
|
||||||
|
model = model_engine.model
|
||||||
|
tokenizer = model_engine.processor
|
||||||
|
|
||||||
|
if infer_dtype == "auto":
|
||||||
|
if model.config.torch_dtype == torch.float32 and torch.cuda.is_bf16_supported():
|
||||||
|
model = model.to(torch.bfloat16)
|
||||||
|
logger.info_rank0("Converted model to bfloat16.")
|
||||||
|
else:
|
||||||
|
target_dtype = getattr(torch, infer_dtype)
|
||||||
|
model = model.to(target_dtype)
|
||||||
|
logger.info_rank0(f"Converted model to {infer_dtype}.")
|
||||||
|
|
||||||
|
logger.info_rank0(f"Exporting model to {export_dir}...")
|
||||||
|
model.save_pretrained(
|
||||||
|
export_dir,
|
||||||
|
max_shard_size=f"{export_size}GB",
|
||||||
|
safe_serialization=not export_legacy_format,
|
||||||
|
)
|
||||||
|
if tokenizer is not None:
|
||||||
|
try:
|
||||||
|
if hasattr(tokenizer, "padding_side"):
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
tokenizer.save_pretrained(export_dir)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to save tokenizer: {e}")
|
||||||
|
|
||||||
|
if export_hub_model_id:
|
||||||
|
logger.info_rank0(f"Pushing to hub: {export_hub_model_id}...")
|
||||||
|
model.push_to_hub(export_hub_model_id)
|
||||||
|
if tokenizer is not None:
|
||||||
|
tokenizer.push_to_hub(export_hub_model_id)
|
||||||
|
|
||||||
|
logger.info_rank0("Model exported successfully.")
|
||||||
|
|||||||
@@ -0,0 +1,122 @@
|
|||||||
|
# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is inspired by the HuggingFace's transformers library.
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import BitsAndBytesConfig
|
||||||
|
|
||||||
|
from ...accelerator.helper import get_current_device
|
||||||
|
from ...config.model_args import ModelArguments
|
||||||
|
from ...utils import logging
|
||||||
|
from ...utils.packages import check_version
|
||||||
|
from ...utils.plugin import BasePlugin
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizationPlugin(BasePlugin):
|
||||||
|
r"""Plugin for model quantization."""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
init_kwargs: dict[str, Any] = None,
|
||||||
|
config: "PretrainedConfig" = None,
|
||||||
|
tokenizer: "PreTrainedTokenizer" = None,
|
||||||
|
model_args: "ModelArguments" = None,
|
||||||
|
is_trainable: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return super().__call__(
|
||||||
|
init_kwargs, config=config, tokenizer=tokenizer, model_args=model_args, is_trainable=is_trainable
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@QuantizationPlugin("auto").register()
|
||||||
|
def quantization_auto(
|
||||||
|
init_kwargs: dict[str, Any],
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Automatic quantization selection, only support bnb currently.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
init_kwargs (dict[str, Any]): The kwargs for model initialization.
|
||||||
|
**kwargs: Keyword arguments containing the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, Any]: The updated kwargs for model initialization.
|
||||||
|
"""
|
||||||
|
model_args: ModelArguments = kwargs.get("model_args", None)
|
||||||
|
quant_config = model_args.quant_config
|
||||||
|
|
||||||
|
quantization_bit = quant_config.get("quantization_bit", None)
|
||||||
|
if quantization_bit is not None:
|
||||||
|
logger.info_rank0(f"Loading {quantization_bit}-bit quantized model.")
|
||||||
|
if quantization_bit in [8, 4]:
|
||||||
|
return quantization_with_bnb(init_kwargs, **kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported quantization bit: {quantization_bit} for auto quantization.")
|
||||||
|
logger.warning_rank0("No quantization method applied.")
|
||||||
|
return init_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
@QuantizationPlugin("bnb").register()
|
||||||
|
def quantization_with_bnb(
|
||||||
|
init_kwargs: dict[str, Any],
|
||||||
|
model_args: "ModelArguments" = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
r"""Quantization with BNB."""
|
||||||
|
logger.info_rank0("Using Bitsandbytes quantization.")
|
||||||
|
quantization_bit = model_args.quant_config.get("quantization_bit", None)
|
||||||
|
if quantization_bit is None:
|
||||||
|
logger.warning_rank0("quantization_bit is not specified, default to 8-bit quantization.")
|
||||||
|
quantization_bit = 4
|
||||||
|
assert quantization_bit in [8, 4], "Bitsandbytes only accepts 4-bit or 8-bit quantization."
|
||||||
|
if quantization_bit == 8:
|
||||||
|
check_version("bitsandbytes>=0.37.0", mandatory=True)
|
||||||
|
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||||
|
elif quantization_bit == 4:
|
||||||
|
check_version("bitsandbytes>=0.39.0", mandatory=True)
|
||||||
|
init_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
|
load_in_4bit=True,
|
||||||
|
bnb_4bit_compute_dtype=model_args.quant_config.get("compute_dtype", torch.float16),
|
||||||
|
bnb_4bit_use_double_quant=model_args.quant_config.get("double_quantization", True),
|
||||||
|
bnb_4bit_quant_type=model_args.quant_config.get("quantization_type", "nf4"),
|
||||||
|
bnb_4bit_quant_storage=model_args.quant_config.get(
|
||||||
|
"compute_dtype", torch.float16
|
||||||
|
), # crucial for fsdp+qlora
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
|
||||||
|
|
||||||
|
# TODO: improve deepspeed zero3 and fsdp detection.
|
||||||
|
if kwargs.get("is_trainable", False):
|
||||||
|
logger.info_rank0("Detected inference mode, setting device_map for bitsandbytes quantization.")
|
||||||
|
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
|
||||||
|
else:
|
||||||
|
logger.info_rank0("Detected training mode, skip setting device_map for bitsandbytes quantization.")
|
||||||
|
if model_args.quant_config.get("quantization_bit") != 4:
|
||||||
|
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
|
||||||
|
|
||||||
|
check_version("bitsandbytes>=0.43.0", mandatory=True)
|
||||||
|
|
||||||
|
logger.info_rank0(f"Quantizing model to {model_args.quant_config.get('quantization_bit')} bit with bitsandbytes.")
|
||||||
|
return init_kwargs
|
||||||
|
|||||||
@@ -12,224 +12,45 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import importlib
|
||||||
import re
|
|
||||||
|
|
||||||
from ...utils.constants import IGNORE_INDEX
|
from ...utils import logging
|
||||||
from ...utils.helper import get_tokenizer
|
|
||||||
from ...utils.plugin import BasePlugin
|
from ...utils.plugin import BasePlugin
|
||||||
from ...utils.types import Message, ModelInput, Processor, ToolCall
|
from ...utils.types import Message, ModelInput, Processor
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RenderingPlugin(BasePlugin):
|
class RenderingPlugin(BasePlugin):
|
||||||
|
_attempted_template_imports: set[str] = set()
|
||||||
|
|
||||||
|
def _ensure_template_imported(self) -> None:
|
||||||
|
if self.name is None or self.name in self._attempted_template_imports:
|
||||||
|
return
|
||||||
|
|
||||||
|
full_module_name = f"{__package__}.templates.{self.name}"
|
||||||
|
self._attempted_template_imports.add(self.name)
|
||||||
|
try:
|
||||||
|
importlib.import_module(full_module_name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"[Template Registry] Failed to import {full_module_name}: {exc}")
|
||||||
|
|
||||||
|
def __getitem__(self, method_name: str):
|
||||||
|
self._ensure_template_imported()
|
||||||
|
return super().__getitem__(method_name)
|
||||||
|
|
||||||
def render_messages(
|
def render_messages(
|
||||||
self,
|
self,
|
||||||
processor: Processor,
|
processor: Processor,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
tools: str | None = None,
|
tools: str | None = None,
|
||||||
is_generate: bool = False,
|
is_generate: bool = False,
|
||||||
|
enable_thinking: bool = False,
|
||||||
) -> ModelInput:
|
) -> ModelInput:
|
||||||
"""Render messages in the template format."""
|
"""Render messages in the template format."""
|
||||||
return self["render_messages"](processor, messages, tools, is_generate)
|
return self["render_messages"](processor, messages, tools, is_generate, enable_thinking)
|
||||||
|
|
||||||
def parse_messages(self, generated_text: str) -> Message:
|
def parse_messages(self, generated_text: str) -> Message:
|
||||||
"""Parse messages in the template format."""
|
"""Parse messages in the template format."""
|
||||||
return self["parse_messages"](generated_text)
|
return self["parse_messages"](generated_text)
|
||||||
|
|
||||||
|
|
||||||
def _update_model_input(
|
|
||||||
processor: Processor,
|
|
||||||
input_ids: list[int],
|
|
||||||
labels: list[int],
|
|
||||||
loss_weights: list[int],
|
|
||||||
temp_str: str,
|
|
||||||
temp_weight: float,
|
|
||||||
) -> str:
|
|
||||||
"""Update model input with temporary string."""
|
|
||||||
if not temp_str:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
tokenizer = get_tokenizer(processor)
|
|
||||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
|
||||||
input_ids.extend(temp_ids)
|
|
||||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
|
||||||
if temp_weight > 1e-6:
|
|
||||||
labels.extend(temp_ids)
|
|
||||||
else:
|
|
||||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
|
||||||
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
|
||||||
def render_qwen3_nothink_messages(
|
|
||||||
processor: Processor,
|
|
||||||
messages: list[Message],
|
|
||||||
tools: str | None = None,
|
|
||||||
is_generate: bool = False,
|
|
||||||
) -> ModelInput:
|
|
||||||
"""Render messages in the Qwen3 nothink template format.
|
|
||||||
|
|
||||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
|
|
||||||
"""
|
|
||||||
input_ids, labels, loss_weights = [], [], []
|
|
||||||
temp_str, temp_weight = "", 0.0
|
|
||||||
if tools:
|
|
||||||
temp_str += "<|im_start|>system\n"
|
|
||||||
if messages[0]["role"] == "system":
|
|
||||||
for content in messages[0]["content"]:
|
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "\n\n"
|
|
||||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
|
||||||
|
|
||||||
temp_str += (
|
|
||||||
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
|
||||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
tools = json.loads(tools)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
|
||||||
|
|
||||||
if not isinstance(tools, list):
|
|
||||||
tools = [tools]
|
|
||||||
|
|
||||||
for tool in tools:
|
|
||||||
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
|
||||||
|
|
||||||
temp_str += (
|
|
||||||
"\n</tools>\n\nFor each function call, return a json object with function name "
|
|
||||||
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
|
||||||
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
|
||||||
)
|
|
||||||
elif messages[0]["role"] == "system":
|
|
||||||
temp_str += "<|im_start|>system\n"
|
|
||||||
for content in messages[0]["content"]:
|
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "<|im_end|>\n"
|
|
||||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
|
||||||
|
|
||||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
|
||||||
|
|
||||||
for turn_idx, message in enumerate(messages):
|
|
||||||
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
|
||||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
|
||||||
for content in message["content"]:
|
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "<|im_end|>\n"
|
|
||||||
temp_weight = message.get("loss_weight", 0.0)
|
|
||||||
elif message["role"] == "assistant":
|
|
||||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
|
||||||
for val_idx, content in enumerate(message["content"]):
|
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
elif content["type"] == "reasoning":
|
|
||||||
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
|
||||||
elif content["type"] == "tool_call":
|
|
||||||
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
|
||||||
temp_str += "\n"
|
|
||||||
|
|
||||||
try:
|
|
||||||
tool_call: ToolCall = json.loads(content["value"])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
|
||||||
|
|
||||||
temp_str += (
|
|
||||||
'<tool_call>\n{"name": "'
|
|
||||||
+ tool_call["name"]
|
|
||||||
+ '", "arguments": '
|
|
||||||
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
|
||||||
+ "}\n</tool_call>"
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "<|im_end|>\n"
|
|
||||||
temp_weight = message.get("loss_weight", 1.0)
|
|
||||||
elif message["role"] == "tool":
|
|
||||||
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
|
||||||
temp_str += "<|im_start|>user"
|
|
||||||
|
|
||||||
temp_str += "\n<tool_response>\n"
|
|
||||||
for content in message["content"]:
|
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "\n</tool_response>"
|
|
||||||
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
|
||||||
temp_str += "<|im_end|>\n"
|
|
||||||
|
|
||||||
temp_weight = message.get("loss_weight", 0.0)
|
|
||||||
|
|
||||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
|
||||||
|
|
||||||
if is_generate:
|
|
||||||
temp_str += "<|im_start|>assistant\n"
|
|
||||||
temp_weight = 0.0
|
|
||||||
|
|
||||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
|
||||||
|
|
||||||
attention_mask = [1] * len(input_ids)
|
|
||||||
return ModelInput(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
labels=labels,
|
|
||||||
loss_weights=loss_weights,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
|
||||||
def parse_qwen3_nothink_message(generated_text: str) -> Message:
|
|
||||||
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
generated_text (str): The generated text in the Qwen3 nothink template format.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Message: The parsed message.
|
|
||||||
"""
|
|
||||||
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
|
||||||
content = []
|
|
||||||
last_end = 0
|
|
||||||
for match in pattern.finditer(generated_text):
|
|
||||||
start, end = match.span()
|
|
||||||
if start > last_end:
|
|
||||||
text = generated_text[last_end:start].strip()
|
|
||||||
if text:
|
|
||||||
content.append({"type": "text", "value": text})
|
|
||||||
|
|
||||||
tag_type = match.group(1)
|
|
||||||
tag_value = match.group(2).strip()
|
|
||||||
if tag_type == "thinking":
|
|
||||||
content.append({"type": "reasoning", "value": tag_value.strip()})
|
|
||||||
elif tag_type == "tool_call":
|
|
||||||
try:
|
|
||||||
json.loads(tag_value.strip())
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
|
||||||
|
|
||||||
content.append({"type": "tool_call", "value": tag_value.strip()})
|
|
||||||
|
|
||||||
last_end = end
|
|
||||||
|
|
||||||
if last_end < len(generated_text):
|
|
||||||
text = generated_text[last_end:].strip()
|
|
||||||
if text:
|
|
||||||
content.append({"type": "text", "value": text})
|
|
||||||
|
|
||||||
return Message(role="assistant", content=content)
|
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
259
src/llamafactory/v1/plugins/model_plugins/templates/qwen3.py
Normal file
259
src/llamafactory/v1/plugins/model_plugins/templates/qwen3.py
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
from ....utils.constants import IGNORE_INDEX
|
||||||
|
from ....utils.helper import get_tokenizer
|
||||||
|
from ....utils.types import Message, ModelInput, Processor, ToolCall
|
||||||
|
from ..rendering import RenderingPlugin
|
||||||
|
|
||||||
|
|
||||||
|
def _update_model_input(
|
||||||
|
processor: Processor,
|
||||||
|
input_ids: list[int],
|
||||||
|
labels: list[int],
|
||||||
|
loss_weights: list[int],
|
||||||
|
temp_str: str,
|
||||||
|
temp_weight: float,
|
||||||
|
) -> str:
|
||||||
|
"""Update model input with temporary string."""
|
||||||
|
if not temp_str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
tokenizer = get_tokenizer(processor)
|
||||||
|
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||||
|
input_ids.extend(temp_ids)
|
||||||
|
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||||
|
if temp_weight > 1e-6:
|
||||||
|
labels.extend(temp_ids)
|
||||||
|
else:
|
||||||
|
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _concat_text_content(message: Message) -> str:
|
||||||
|
"""Concatenate text fields in a message."""
|
||||||
|
message_text = ""
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
message_text += content["value"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
return message_text
|
||||||
|
|
||||||
|
|
||||||
|
def _get_last_query_index(messages: list[Message]) -> int:
|
||||||
|
"""Find the last user query index, excluding wrapped tool responses."""
|
||||||
|
last_query_index = len(messages) - 1
|
||||||
|
for idx in range(len(messages) - 1, -1, -1):
|
||||||
|
message = messages[idx]
|
||||||
|
if message["role"] != "user":
|
||||||
|
continue
|
||||||
|
|
||||||
|
user_text = ""
|
||||||
|
is_plain_text = True
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] != "text":
|
||||||
|
is_plain_text = False
|
||||||
|
break
|
||||||
|
user_text += content["value"]
|
||||||
|
|
||||||
|
if not is_plain_text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not (user_text.startswith("<tool_response>") and user_text.endswith("</tool_response>")):
|
||||||
|
last_query_index = idx
|
||||||
|
break
|
||||||
|
|
||||||
|
return last_query_index
|
||||||
|
|
||||||
|
|
||||||
|
def _split_assistant_content(message: Message) -> tuple[str, str, list[ToolCall]]:
|
||||||
|
"""Split assistant message into text, reasoning and tool calls."""
|
||||||
|
text_content = ""
|
||||||
|
reasoning_content = ""
|
||||||
|
tool_calls: list[ToolCall] = []
|
||||||
|
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
text_content += content["value"]
|
||||||
|
elif content["type"] == "reasoning":
|
||||||
|
reasoning_content += content["value"]
|
||||||
|
elif content["type"] == "tool_call":
|
||||||
|
try:
|
||||||
|
tool_call: ToolCall = json.loads(content["value"])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
||||||
|
|
||||||
|
tool_calls.append(tool_call)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
return text_content, reasoning_content, tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
@RenderingPlugin("qwen3").register("render_messages")
|
||||||
|
def render_qwen3_messages(
|
||||||
|
processor: Processor,
|
||||||
|
messages: list[Message],
|
||||||
|
tools: str | None = None,
|
||||||
|
is_generate: bool = False,
|
||||||
|
enable_thinking: bool = False,
|
||||||
|
) -> ModelInput:
|
||||||
|
"""Render messages in the Qwen3 template format.
|
||||||
|
|
||||||
|
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-8B
|
||||||
|
"""
|
||||||
|
input_ids, labels, loss_weights = [], [], []
|
||||||
|
temp_str, temp_weight = "", 0.0
|
||||||
|
if tools:
|
||||||
|
temp_str += "<|im_start|>system\n"
|
||||||
|
if messages[0]["role"] == "system":
|
||||||
|
temp_str += _concat_text_content(messages[0]) + "\n\n"
|
||||||
|
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||||
|
|
||||||
|
temp_str += (
|
||||||
|
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||||
|
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
tools = json.loads(tools)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
||||||
|
|
||||||
|
if not isinstance(tools, list):
|
||||||
|
tools = [tools]
|
||||||
|
|
||||||
|
for tool in tools:
|
||||||
|
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||||
|
|
||||||
|
temp_str += (
|
||||||
|
"\n</tools>\n\nFor each function call, return a json object with function name "
|
||||||
|
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
||||||
|
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
||||||
|
)
|
||||||
|
elif messages[0]["role"] == "system":
|
||||||
|
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
|
||||||
|
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||||
|
|
||||||
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||||
|
last_query_index = _get_last_query_index(messages)
|
||||||
|
|
||||||
|
for turn_idx, message in enumerate(messages):
|
||||||
|
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
||||||
|
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
|
||||||
|
temp_weight = message.get("loss_weight", 0.0)
|
||||||
|
elif message["role"] == "assistant":
|
||||||
|
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||||
|
|
||||||
|
text_content, reasoning_content, tool_calls = _split_assistant_content(message)
|
||||||
|
if turn_idx > last_query_index and (turn_idx == len(messages) - 1 or reasoning_content):
|
||||||
|
temp_str += "<think>\n" + reasoning_content.strip("\n") + "\n</think>\n\n" + text_content.lstrip("\n")
|
||||||
|
else:
|
||||||
|
temp_str += text_content
|
||||||
|
|
||||||
|
for tool_call_idx, tool_call in enumerate(tool_calls):
|
||||||
|
if (tool_call_idx == 0 and text_content) or tool_call_idx > 0:
|
||||||
|
temp_str += "\n"
|
||||||
|
|
||||||
|
arguments = tool_call.get("arguments")
|
||||||
|
if isinstance(arguments, str):
|
||||||
|
arguments_str = arguments
|
||||||
|
else:
|
||||||
|
arguments_str = json.dumps(arguments, ensure_ascii=False)
|
||||||
|
|
||||||
|
temp_str += (
|
||||||
|
'<tool_call>\n{"name": "'
|
||||||
|
+ tool_call["name"]
|
||||||
|
+ '", "arguments": '
|
||||||
|
+ arguments_str
|
||||||
|
+ "}\n</tool_call>"
|
||||||
|
)
|
||||||
|
|
||||||
|
temp_str += "<|im_end|>\n"
|
||||||
|
temp_weight = message.get("loss_weight", 1.0)
|
||||||
|
elif message["role"] == "tool":
|
||||||
|
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
||||||
|
temp_str += "<|im_start|>user"
|
||||||
|
|
||||||
|
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
|
||||||
|
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
||||||
|
temp_str += "<|im_end|>\n"
|
||||||
|
|
||||||
|
temp_weight = message.get("loss_weight", 0.0)
|
||||||
|
|
||||||
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||||
|
|
||||||
|
if is_generate:
|
||||||
|
temp_str += "<|im_start|>assistant\n"
|
||||||
|
temp_weight = 0.0
|
||||||
|
if enable_thinking is False:
|
||||||
|
temp_str += "<think>\n\n</think>\n\n"
|
||||||
|
|
||||||
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||||
|
|
||||||
|
attention_mask = [1] * len(input_ids)
|
||||||
|
return ModelInput(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
labels=labels,
|
||||||
|
loss_weights=loss_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@RenderingPlugin("qwen3").register("parse_message")
|
||||||
|
def parse_qwen3_message(generated_text: str) -> Message:
|
||||||
|
"""Parse a message in the Qwen3 template format. Supports interleaved reasoning and tool calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generated_text (str): The generated text in the Qwen3 template format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Message: The parsed message.
|
||||||
|
"""
|
||||||
|
pattern = re.compile(r"<(think|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||||
|
content = []
|
||||||
|
last_end = 0
|
||||||
|
|
||||||
|
for match in pattern.finditer(generated_text):
|
||||||
|
start, end = match.span()
|
||||||
|
if start > last_end:
|
||||||
|
text = generated_text[last_end:start].strip()
|
||||||
|
if text:
|
||||||
|
content.append({"type": "text", "value": text})
|
||||||
|
|
||||||
|
tag_type = match.group(1)
|
||||||
|
tag_value = match.group(2).strip()
|
||||||
|
if tag_type == "think":
|
||||||
|
content.append({"type": "reasoning", "value": tag_value.strip()})
|
||||||
|
elif tag_type == "tool_call":
|
||||||
|
try:
|
||||||
|
json.loads(tag_value.strip())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
||||||
|
|
||||||
|
content.append({"type": "tool_call", "value": tag_value.strip()})
|
||||||
|
|
||||||
|
last_end = end
|
||||||
|
|
||||||
|
if last_end < len(generated_text):
|
||||||
|
text = generated_text[last_end:].strip()
|
||||||
|
if text:
|
||||||
|
content.append({"type": "text", "value": text})
|
||||||
|
|
||||||
|
return Message(role="assistant", content=content)
|
||||||
@@ -0,0 +1,209 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
from ....utils.constants import IGNORE_INDEX
|
||||||
|
from ....utils.helper import get_tokenizer
|
||||||
|
from ....utils.types import Message, ModelInput, Processor, ToolCall
|
||||||
|
from ..rendering import RenderingPlugin
|
||||||
|
|
||||||
|
|
||||||
|
def _update_model_input(
|
||||||
|
processor: Processor,
|
||||||
|
input_ids: list[int],
|
||||||
|
labels: list[int],
|
||||||
|
loss_weights: list[int],
|
||||||
|
temp_str: str,
|
||||||
|
temp_weight: float,
|
||||||
|
) -> str:
|
||||||
|
"""Update model input with temporary string."""
|
||||||
|
if not temp_str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
tokenizer = get_tokenizer(processor)
|
||||||
|
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||||
|
input_ids.extend(temp_ids)
|
||||||
|
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||||
|
if temp_weight > 1e-6:
|
||||||
|
labels.extend(temp_ids)
|
||||||
|
else:
|
||||||
|
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _concat_text_content(message: Message) -> str:
|
||||||
|
"""Concatenate text fields in a message."""
|
||||||
|
message_text = ""
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
message_text += content["value"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
return message_text
|
||||||
|
|
||||||
|
|
||||||
|
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
||||||
|
def render_qwen3_nothink_messages(
|
||||||
|
processor: Processor,
|
||||||
|
messages: list[Message],
|
||||||
|
tools: str | None = None,
|
||||||
|
is_generate: bool = False,
|
||||||
|
enable_thinking: bool = False,
|
||||||
|
) -> ModelInput:
|
||||||
|
"""Render messages in the Qwen3 nothink template format.
|
||||||
|
|
||||||
|
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507
|
||||||
|
"""
|
||||||
|
input_ids, labels, loss_weights = [], [], []
|
||||||
|
temp_str, temp_weight = "", 0.0
|
||||||
|
if tools:
|
||||||
|
temp_str += "<|im_start|>system\n"
|
||||||
|
if messages[0]["role"] == "system":
|
||||||
|
temp_str += _concat_text_content(messages[0]) + "\n\n"
|
||||||
|
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||||
|
|
||||||
|
temp_str += (
|
||||||
|
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||||
|
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
tools = json.loads(tools)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
||||||
|
|
||||||
|
if not isinstance(tools, list):
|
||||||
|
tools = [tools]
|
||||||
|
|
||||||
|
for tool in tools:
|
||||||
|
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||||
|
|
||||||
|
temp_str += (
|
||||||
|
"\n</tools>\n\nFor each function call, return a json object with function name "
|
||||||
|
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
||||||
|
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
||||||
|
)
|
||||||
|
elif messages[0]["role"] == "system":
|
||||||
|
temp_str += "<|im_start|>system\n" + _concat_text_content(messages[0]) + "<|im_end|>\n"
|
||||||
|
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||||
|
|
||||||
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||||
|
|
||||||
|
for turn_idx, message in enumerate(messages):
|
||||||
|
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
||||||
|
temp_str += "<|im_start|>" + message["role"] + "\n" + _concat_text_content(message) + "<|im_end|>\n"
|
||||||
|
temp_weight = message.get("loss_weight", 0.0)
|
||||||
|
elif message["role"] == "assistant":
|
||||||
|
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||||
|
for val_idx, content in enumerate(message["content"]):
|
||||||
|
if content["type"] == "text":
|
||||||
|
temp_str += content["value"]
|
||||||
|
elif content["type"] == "reasoning":
|
||||||
|
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
||||||
|
elif content["type"] == "tool_call":
|
||||||
|
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
||||||
|
temp_str += "\n"
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_call: ToolCall = json.loads(content["value"])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
||||||
|
|
||||||
|
temp_str += (
|
||||||
|
'<tool_call>\n{"name": "'
|
||||||
|
+ tool_call["name"]
|
||||||
|
+ '", "arguments": '
|
||||||
|
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
||||||
|
+ "}\n</tool_call>"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
temp_str += "<|im_end|>\n"
|
||||||
|
temp_weight = message.get("loss_weight", 1.0)
|
||||||
|
elif message["role"] == "tool":
|
||||||
|
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
||||||
|
temp_str += "<|im_start|>user"
|
||||||
|
|
||||||
|
temp_str += "\n<tool_response>\n" + _concat_text_content(message) + "\n</tool_response>"
|
||||||
|
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
||||||
|
temp_str += "<|im_end|>\n"
|
||||||
|
|
||||||
|
temp_weight = message.get("loss_weight", 0.0)
|
||||||
|
|
||||||
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||||
|
|
||||||
|
if is_generate:
|
||||||
|
temp_str += "<|im_start|>assistant\n"
|
||||||
|
temp_weight = 0.0
|
||||||
|
if enable_thinking:
|
||||||
|
raise ValueError("The qwen3_nothink template does not support thinking mode.")
|
||||||
|
|
||||||
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||||
|
|
||||||
|
attention_mask = [1] * len(input_ids)
|
||||||
|
return ModelInput(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
labels=labels,
|
||||||
|
loss_weights=loss_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
||||||
|
def parse_qwen3_nothink_message(generated_text: str) -> Message:
|
||||||
|
"""Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generated_text (str): The generated text in the Qwen3 nothink template format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Message: The parsed message.
|
||||||
|
"""
|
||||||
|
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||||
|
content = []
|
||||||
|
last_end = 0
|
||||||
|
|
||||||
|
for match in pattern.finditer(generated_text):
|
||||||
|
start, end = match.span()
|
||||||
|
if start > last_end:
|
||||||
|
text = generated_text[last_end:start].strip()
|
||||||
|
if text:
|
||||||
|
content.append({"type": "text", "value": text})
|
||||||
|
|
||||||
|
tag_type = match.group(1)
|
||||||
|
tag_value = match.group(2).strip()
|
||||||
|
if tag_type == "thinking":
|
||||||
|
content.append({"type": "reasoning", "value": tag_value.strip()})
|
||||||
|
elif tag_type == "tool_call":
|
||||||
|
try:
|
||||||
|
json.loads(tag_value.strip())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
||||||
|
|
||||||
|
content.append({"type": "tool_call", "value": tag_value.strip()})
|
||||||
|
|
||||||
|
last_end = end
|
||||||
|
|
||||||
|
if last_end < len(generated_text):
|
||||||
|
text = generated_text[last_end:].strip()
|
||||||
|
if text:
|
||||||
|
content.append({"type": "text", "value": text})
|
||||||
|
|
||||||
|
return Message(role="assistant", content=content)
|
||||||
@@ -0,0 +1,129 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""DeepSpeed integration via accelerate's built-in capabilities.
|
||||||
|
|
||||||
|
Instead of manually calling deepspeed.initialize() and syncing config,
|
||||||
|
this module leverages accelerate's Accelerator + DeepSpeedPlugin to handle
|
||||||
|
initialization, backward, gradient accumulation, and model saving.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from accelerate import Accelerator
|
||||||
|
from accelerate.utils import DeepSpeedPlugin
|
||||||
|
|
||||||
|
from ....utils.logging import get_logger
|
||||||
|
from ....utils.types import HFModel, Processor
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSpeedEngine:
|
||||||
|
"""DeepSpeed integration using accelerate's built-in capabilities.
|
||||||
|
|
||||||
|
This replaces the manual DeepSpeedConfigHelper / DeepSpeedEngine approach
|
||||||
|
with accelerate's Accelerator + DeepSpeedPlugin, which handles:
|
||||||
|
- Config syncing (auto values, batch size, lr, etc.)
|
||||||
|
- deepspeed.initialize() call
|
||||||
|
- Optimizer / LR scheduler wrapping
|
||||||
|
- Backward + gradient accumulation boundary
|
||||||
|
- ZeRO-3 parameter gathering for saving
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dist_config: dict[str, Any], num_micro_batch: int = 1, micro_batch_size: int = 1):
|
||||||
|
config_file = dist_config.get("config_file")
|
||||||
|
if not config_file:
|
||||||
|
raise ValueError("DeepSpeed config_file is required in dist_config")
|
||||||
|
|
||||||
|
ds_plugin = DeepSpeedPlugin(hf_ds_config=config_file)
|
||||||
|
|
||||||
|
self.accelerator = Accelerator(
|
||||||
|
deepspeed_plugin=ds_plugin,
|
||||||
|
gradient_accumulation_steps=num_micro_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resolve "auto" for train_micro_batch_size_per_gpu so that
|
||||||
|
# accelerate.prepare() does not require a DataLoader to infer it.
|
||||||
|
ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
|
||||||
|
if ds_config.get("train_micro_batch_size_per_gpu") in (None, "auto"):
|
||||||
|
ds_config["train_micro_batch_size_per_gpu"] = micro_batch_size
|
||||||
|
|
||||||
|
logger.info_rank0(f"DeepSpeedEngine initialized with config: {config_file}")
|
||||||
|
|
||||||
|
def shard_model(self, model: HFModel) -> "DeepSpeedEngine":
|
||||||
|
"""No-op shard — actual model wrapping happens in prepare().
|
||||||
|
|
||||||
|
Returns self so the caller gets the engine instance via the hub interface.
|
||||||
|
"""
|
||||||
|
return self
|
||||||
|
|
||||||
|
def prepare(
|
||||||
|
self,
|
||||||
|
model: HFModel,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
lr_scheduler: Optional[Any] = None,
|
||||||
|
) -> tuple[HFModel, torch.optim.Optimizer, Any]:
|
||||||
|
"""Prepare model, optimizer, and lr_scheduler using accelerate.
|
||||||
|
|
||||||
|
Internally calls deepspeed.initialize() and wraps the returned objects.
|
||||||
|
"""
|
||||||
|
if lr_scheduler is not None:
|
||||||
|
model, optimizer, lr_scheduler = self.accelerator.prepare(model, optimizer, lr_scheduler)
|
||||||
|
else:
|
||||||
|
model, optimizer = self.accelerator.prepare(model, optimizer)
|
||||||
|
|
||||||
|
model._accelerator = self.accelerator # type: ignore[assignment]
|
||||||
|
|
||||||
|
logger.info_rank0("Model, optimizer, and lr_scheduler prepared via accelerate")
|
||||||
|
return model, optimizer, lr_scheduler
|
||||||
|
|
||||||
|
def backward(self, loss: torch.Tensor) -> None:
|
||||||
|
"""Backward pass using accelerate.
|
||||||
|
|
||||||
|
Delegates to DeepSpeedEngineWrapper.backward() which respects
|
||||||
|
sync_gradients to control gradient accumulation boundaries.
|
||||||
|
When sync_gradients=True: engine.backward(loss) + engine.step()
|
||||||
|
When sync_gradients=False: engine.backward(loss) only
|
||||||
|
"""
|
||||||
|
self.accelerator.backward(loss)
|
||||||
|
|
||||||
|
def get_grad_norm(self) -> float:
|
||||||
|
"""Get the global gradient norm from the DeepSpeed engine."""
|
||||||
|
engine_wrapper = getattr(self.accelerator, "deepspeed_engine_wrapped", None)
|
||||||
|
if engine_wrapper is not None:
|
||||||
|
return engine_wrapper.engine.get_global_grad_norm() or 0.0
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
|
||||||
|
"""Save model using accelerate's built-in ZeRO-aware utilities.
|
||||||
|
|
||||||
|
Expects model._accelerator to be set during prepare().
|
||||||
|
Handles ZeRO-3 parameter gathering automatically via
|
||||||
|
accelerator.get_state_dict().
|
||||||
|
"""
|
||||||
|
accelerator: Accelerator = model._accelerator # type: ignore[union-attr]
|
||||||
|
|
||||||
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
state_dict = accelerator.get_state_dict(model)
|
||||||
|
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
unwrapped_model.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB")
|
||||||
|
processor.save_pretrained(output_dir, max_shard_size="4GB")
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
logger.info_rank0(f"Model saved to {output_dir}")
|
||||||
|
|||||||
@@ -12,28 +12,30 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from peft.tuners.lora import LoraLayer
|
||||||
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
|
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
|
||||||
from torch.distributed.fsdp import (
|
from torch.distributed.fsdp import (
|
||||||
CPUOffloadPolicy,
|
CPUOffloadPolicy,
|
||||||
MixedPrecisionPolicy,
|
MixedPrecisionPolicy,
|
||||||
fully_shard,
|
fully_shard,
|
||||||
)
|
)
|
||||||
from transformers import PreTrainedModel
|
|
||||||
|
|
||||||
from ....accelerator.helper import get_current_accelerator
|
from ....accelerator.helper import get_current_accelerator
|
||||||
from ....accelerator.interface import DistributedInterface
|
from ....accelerator.interface import DistributedInterface
|
||||||
from ....utils.logging import get_logger
|
from ....utils.logging import get_logger
|
||||||
|
from ....utils.types import HFModel, Processor
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
|
def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
|
||||||
no_split_modules = getattr(model, "_no_split_modules", None)
|
no_split_modules = getattr(model, "_no_split_modules", None)
|
||||||
if no_split_modules:
|
if no_split_modules:
|
||||||
if isinstance(no_split_modules, (list, tuple)):
|
if isinstance(no_split_modules, (list, tuple)):
|
||||||
@@ -49,6 +51,20 @@ def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
|
||||||
|
if DistributedInterface().get_rank() == 0:
|
||||||
|
logger.info("Gathering state dict for saving...")
|
||||||
|
|
||||||
|
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
|
||||||
|
state_dict = get_model_state_dict(model, options=options)
|
||||||
|
|
||||||
|
if DistributedInterface().get_rank() == 0:
|
||||||
|
model_to_save = model.module if hasattr(model, "module") else model
|
||||||
|
model_to_save.save_pretrained(output_dir, state_dict=state_dict, max_shard_size="4GB")
|
||||||
|
processor.save_pretrained(output_dir, max_shard_size="4GB")
|
||||||
|
logger.info(f"Model saved to {output_dir}")
|
||||||
|
|
||||||
|
|
||||||
class FSDP2Engine:
|
class FSDP2Engine:
|
||||||
def __init__(self, dist_config: dict):
|
def __init__(self, dist_config: dict):
|
||||||
self.dist_interface = DistributedInterface()
|
self.dist_interface = DistributedInterface()
|
||||||
@@ -94,7 +110,10 @@ class FSDP2Engine:
|
|||||||
cast_forward_inputs=True,
|
cast_forward_inputs=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel:
|
def is_lora_module_wrap(self, model) -> bool:
|
||||||
|
return any(isinstance(module, LoraLayer) for module in model.modules())
|
||||||
|
|
||||||
|
def prepare_model(self, model: HFModel) -> HFModel:
|
||||||
if self.fsdp_mesh is None:
|
if self.fsdp_mesh is None:
|
||||||
logger.warning("No FSDP Mesh available, skipping FSDP wrapping.")
|
logger.warning("No FSDP Mesh available, skipping FSDP wrapping.")
|
||||||
return model
|
return model
|
||||||
@@ -111,6 +130,25 @@ class FSDP2Engine:
|
|||||||
logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
|
logger.info(f"Applying per-layer FSDP to {layer_cls.__name__}")
|
||||||
transformer_layer_cls_to_wrap = {layer_cls}
|
transformer_layer_cls_to_wrap = {layer_cls}
|
||||||
|
|
||||||
|
if self.is_lora_module_wrap(model):
|
||||||
|
lora_modules = []
|
||||||
|
for module in model.modules():
|
||||||
|
if len(list(module.children())) != 0:
|
||||||
|
continue
|
||||||
|
if any(param.requires_grad for param in module.parameters(recurse=False)):
|
||||||
|
lora_modules.append(module)
|
||||||
|
|
||||||
|
for module in lora_modules:
|
||||||
|
fully_shard(
|
||||||
|
module,
|
||||||
|
mesh=self.fsdp_mesh,
|
||||||
|
reshard_after_forward=self.reshard_after_forward,
|
||||||
|
mp_policy=mp_policy,
|
||||||
|
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Applying FSDP wrap for LoRA layer separately.")
|
||||||
|
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
should_wrap = False
|
should_wrap = False
|
||||||
|
|
||||||
@@ -129,12 +167,11 @@ class FSDP2Engine:
|
|||||||
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
|
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
use_gradient_checkpointing = True # Could be configurable
|
# BaseTrainer is the single source of truth for gradient checkpointing.
|
||||||
if use_gradient_checkpointing:
|
# FSDP2 only applies the input-grad compatibility hook when checkpointing is already enabled.
|
||||||
|
if getattr(model, "is_gradient_checkpointing", False):
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
logger.info("Enabling gradient checkpointing (transformers native)...")
|
logger.info("Gradient checkpointing is enabled. Applying FSDP2 input grad preparation.")
|
||||||
|
|
||||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
|
||||||
|
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
model.enable_input_require_grads()
|
model.enable_input_require_grads()
|
||||||
@@ -156,7 +193,7 @@ class FSDP2Engine:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None):
|
def materialize_and_load(self, model: HFModel, hf_model_path: str, dcp_path: str = None):
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
logger.info("Materializing sharded model params...")
|
logger.info("Materializing sharded model params...")
|
||||||
|
|
||||||
@@ -176,15 +213,57 @@ class FSDP2Engine:
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def shard_model(self, model: PreTrainedModel) -> PreTrainedModel:
|
def _save_non_persistent_buffers(self, model: HFModel) -> dict:
|
||||||
|
"""Save non-persistent buffers, such as inv_freq."""
|
||||||
|
saved = {}
|
||||||
|
for mod_name, module in model.named_modules():
|
||||||
|
for buf_name in module._non_persistent_buffers_set:
|
||||||
|
fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
|
||||||
|
buf = getattr(module, buf_name, None)
|
||||||
|
if buf is not None:
|
||||||
|
saved[fqn] = copy.deepcopy(buf)
|
||||||
|
if self.rank == 0 and saved:
|
||||||
|
logger.info(f"Saved {len(saved)} non-persistent buffers")
|
||||||
|
return saved
|
||||||
|
|
||||||
|
def _restore_non_persistent_buffers(self, model: HFModel, saved_buffers: dict):
|
||||||
|
"""Register saved non-persistent buffers to model."""
|
||||||
|
if not saved_buffers:
|
||||||
|
return
|
||||||
|
device = get_current_accelerator()
|
||||||
|
for fqn, buf in saved_buffers.items():
|
||||||
|
buf = buf.to(device)
|
||||||
|
if "." in fqn:
|
||||||
|
parent_fqn, buf_name = fqn.rsplit(".", 1)
|
||||||
|
parent_module = model.get_submodule(parent_fqn)
|
||||||
|
else:
|
||||||
|
buf_name = fqn
|
||||||
|
parent_module = model
|
||||||
|
parent_module.register_buffer(buf_name, buf, persistent=False)
|
||||||
|
if self.rank == 0:
|
||||||
|
logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")
|
||||||
|
|
||||||
|
def shard_model(self, model: HFModel) -> HFModel:
|
||||||
if model.device.type == "meta":
|
if model.device.type == "meta":
|
||||||
|
non_persistent_buffers = self._save_non_persistent_buffers(model)
|
||||||
|
|
||||||
|
if getattr(model.config, "tie_word_embeddings", None):
|
||||||
|
model.tie_weights()
|
||||||
|
|
||||||
model = self.prepare_model(model)
|
model = self.prepare_model(model)
|
||||||
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
|
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
|
||||||
|
|
||||||
|
# fix tied broken for no-fsdp-wrap case
|
||||||
|
if getattr(model.config, "tie_word_embeddings", None):
|
||||||
|
model.tie_weights()
|
||||||
|
|
||||||
|
self._restore_non_persistent_buffers(model, non_persistent_buffers)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
model = self.prepare_model(model)
|
model = self.prepare_model(model)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str):
|
def _load_from_dcp(self, model: HFModel, dcp_path: str):
|
||||||
import torch.distributed.checkpoint as dcp
|
import torch.distributed.checkpoint as dcp
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -203,7 +282,7 @@ class FSDP2Engine:
|
|||||||
logger.error(f"Failed to load from DCP: {e}")
|
logger.error(f"Failed to load from DCP: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _load_weights_from_hf_checkpoint(self, model, hf_model_path):
|
def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
|
||||||
import glob
|
import glob
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|||||||
@@ -12,9 +12,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ....config.arg_utils import PluginConfig
|
from ....config.arg_utils import PluginConfig
|
||||||
from ....utils.plugin import BasePlugin
|
from ....utils.plugin import BasePlugin
|
||||||
from ....utils.types import HFModel
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ....utils.types import HFModel, Processor
|
||||||
|
|
||||||
|
|
||||||
class DistributedPlugin(BasePlugin):
|
class DistributedPlugin(BasePlugin):
|
||||||
@@ -23,12 +30,32 @@ class DistributedPlugin(BasePlugin):
|
|||||||
|
|
||||||
|
|
||||||
@DistributedPlugin("fsdp2").register()
|
@DistributedPlugin("fsdp2").register()
|
||||||
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig) -> HFModel:
|
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
|
||||||
from .fsdp2 import FSDP2Engine
|
from .fsdp2 import FSDP2Engine
|
||||||
|
|
||||||
return FSDP2Engine(dist_config).shard_model(model)
|
return FSDP2Engine(dist_config).shard_model(model)
|
||||||
|
|
||||||
|
|
||||||
|
@DistributedPlugin("fsdp2").register("save_model")
|
||||||
|
def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> None:
|
||||||
|
from .fsdp2 import save_model
|
||||||
|
|
||||||
|
return save_model(model, output_dir, processor)
|
||||||
|
|
||||||
|
|
||||||
@DistributedPlugin("deepspeed").register()
|
@DistributedPlugin("deepspeed").register()
|
||||||
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig) -> HFModel:
|
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
|
||||||
return model
|
from .deepspeed import DeepSpeedEngine
|
||||||
|
|
||||||
|
return DeepSpeedEngine(
|
||||||
|
dist_config,
|
||||||
|
num_micro_batch=kwargs.get("num_micro_batch"),
|
||||||
|
micro_batch_size=kwargs.get("micro_batch_size"),
|
||||||
|
).shard_model(model)
|
||||||
|
|
||||||
|
|
||||||
|
@DistributedPlugin("deepspeed").register("save_model")
|
||||||
|
def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor) -> None:
|
||||||
|
from .deepspeed import save_model
|
||||||
|
|
||||||
|
return save_model(model, output_dir, processor)
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ def run_sft(args: InputArgument = None):
|
|||||||
model_args, data_args, training_args, _ = get_args(args)
|
model_args, data_args, training_args, _ = get_args(args)
|
||||||
DistributedInterface(training_args.dist_config)
|
DistributedInterface(training_args.dist_config)
|
||||||
train_dataset = DataEngine(data_args.train_dataset)
|
train_dataset = DataEngine(data_args.train_dataset)
|
||||||
model_engine = ModelEngine(model_args)
|
model_engine = ModelEngine(model_args, is_train=True)
|
||||||
trainer = SFTTrainer(
|
trainer = SFTTrainer(
|
||||||
args=training_args,
|
args=training_args,
|
||||||
model=model_engine.model,
|
model=model_engine.model,
|
||||||
|
|||||||
@@ -15,12 +15,22 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
from transformers import set_seed as hf_set_seed
|
||||||
|
|
||||||
from ..accelerator.interface import DistributedInterface
|
from ..accelerator.interface import DistributedInterface
|
||||||
from .constants import IGNORE_INDEX
|
from .constants import IGNORE_INDEX
|
||||||
from .types import BatchInput, ModelInput, Processor, Tensor
|
from .types import BatchInput, ModelInput, Processor, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed: int) -> None:
|
||||||
|
"""Set seed for reproducibility.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: Random seed.
|
||||||
|
"""
|
||||||
|
hf_set_seed(seed)
|
||||||
|
|
||||||
|
|
||||||
def is_tokenizer(processor: Processor) -> bool:
|
def is_tokenizer(processor: Processor) -> bool:
|
||||||
"""Check if processor is tokenizer.
|
"""Check if processor is tokenizer.
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,13 @@ from functools import lru_cache
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
from . import logging
|
||||||
|
from .env import is_env_enabled
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -41,3 +48,22 @@ def _get_package_version(name: str) -> "Version":
|
|||||||
@lru_cache
|
@lru_cache
|
||||||
def is_transformers_version_greater_than(content: str):
|
def is_transformers_version_greater_than(content: str):
|
||||||
return _get_package_version("transformers") >= version.parse(content)
|
return _get_package_version("transformers") >= version.parse(content)
|
||||||
|
|
||||||
|
|
||||||
|
def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||||
|
r"""Optionally check the package version."""
|
||||||
|
if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
|
||||||
|
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if "gptqmodel" in requirement or "autoawq" in requirement:
|
||||||
|
pip_command = f"pip install {requirement} --no-build-isolation"
|
||||||
|
else:
|
||||||
|
pip_command = f"pip install {requirement}"
|
||||||
|
|
||||||
|
if mandatory:
|
||||||
|
hint = f"To fix: run `{pip_command}`."
|
||||||
|
else:
|
||||||
|
hint = f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
|
||||||
|
|
||||||
|
require_version(requirement, hint)
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class DistributedConfig(TypedDict, total=False):
|
|||||||
|
|
||||||
|
|
||||||
class Content(TypedDict):
|
class Content(TypedDict):
|
||||||
type: Literal["text", "reasoning", "tool_call", "image_url"]
|
type: Literal["text", "reasoning", "tool_call", "image_url", "video_url", "audio_url"]
|
||||||
"""Type of the content."""
|
"""Type of the content."""
|
||||||
value: str
|
value: str
|
||||||
"""Value of the content."""
|
"""Value of the content."""
|
||||||
|
|||||||
@@ -108,11 +108,26 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
|
|||||||
with gr.Column():
|
with gr.Column():
|
||||||
enable_thinking = gr.Checkbox(value=True)
|
enable_thinking = gr.Checkbox(value=True)
|
||||||
report_to = gr.Dropdown(
|
report_to = gr.Dropdown(
|
||||||
choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"],
|
choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "trackio", "all"],
|
||||||
value="none",
|
value="none",
|
||||||
allow_custom_value=True,
|
allow_custom_value=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with gr.Accordion("Trackio Settings", open=False):
|
||||||
|
project = gr.Textbox(
|
||||||
|
value="huggingface",
|
||||||
|
label="Project Name",
|
||||||
|
info="Project name for experiment tracking (used by Trackio, W&B, etc.)",
|
||||||
|
)
|
||||||
|
|
||||||
|
trackio_space_id = gr.Textbox(
|
||||||
|
value="trackio", label="Trackio Space ID", info="Hugging Face Space ID for Trackio deployment"
|
||||||
|
)
|
||||||
|
|
||||||
|
hub_private_repo = gr.Checkbox(
|
||||||
|
value=False, label="Private Repository", info="Make the Hugging Face repository private"
|
||||||
|
)
|
||||||
|
|
||||||
input_elems.update(
|
input_elems.update(
|
||||||
{
|
{
|
||||||
logging_steps,
|
logging_steps,
|
||||||
@@ -128,6 +143,9 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
|
|||||||
use_llama_pro,
|
use_llama_pro,
|
||||||
enable_thinking,
|
enable_thinking,
|
||||||
report_to,
|
report_to,
|
||||||
|
project,
|
||||||
|
trackio_space_id,
|
||||||
|
hub_private_repo,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
elem_dict.update(
|
elem_dict.update(
|
||||||
@@ -146,6 +164,9 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
|
|||||||
use_llama_pro=use_llama_pro,
|
use_llama_pro=use_llama_pro,
|
||||||
enable_thinking=enable_thinking,
|
enable_thinking=enable_thinking,
|
||||||
report_to=report_to,
|
report_to=report_to,
|
||||||
|
project=project,
|
||||||
|
trackio_space_id=trackio_space_id,
|
||||||
|
hub_private_repo=hub_private_repo,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -166,3 +166,33 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
|||||||
def fix_valuehead_cpu_loading():
|
def fix_valuehead_cpu_loading():
|
||||||
"""Fix valuehead model loading."""
|
"""Fix valuehead model loading."""
|
||||||
patch_valuehead_model()
|
patch_valuehead_model()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def bypass_mistral_regex_check():
|
||||||
|
"""Disable Mistral regex network check.
|
||||||
|
|
||||||
|
Monkey-patch TokenizersBackend._patch_mistral_regex into a no-op.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from transformers.tokenization_utils_fast import TokenizersBackend
|
||||||
|
except ImportError:
|
||||||
|
# Very old transformers, nothing to patch
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
if not hasattr(TokenizersBackend, "_patch_mistral_regex"):
|
||||||
|
# Method does not exist in this version
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
# Backup original method
|
||||||
|
original = TokenizersBackend._patch_mistral_regex
|
||||||
|
|
||||||
|
# Replace with no-op
|
||||||
|
TokenizersBackend._patch_mistral_regex = lambda cls, tokenizer, *args, **kwargs: tokenizer
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Restore original method
|
||||||
|
TokenizersBackend._patch_mistral_regex = original
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@@ -22,6 +23,7 @@ from transformers import AutoConfig, AutoModelForImageTextToText
|
|||||||
from llamafactory.data import get_template_and_fix_tokenizer
|
from llamafactory.data import get_template_and_fix_tokenizer
|
||||||
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
|
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
|
||||||
from llamafactory.extras.constants import IGNORE_INDEX
|
from llamafactory.extras.constants import IGNORE_INDEX
|
||||||
|
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||||
from llamafactory.hparams import get_infer_args
|
from llamafactory.hparams import get_infer_args
|
||||||
from llamafactory.model import load_tokenizer
|
from llamafactory.model import load_tokenizer
|
||||||
|
|
||||||
@@ -116,19 +118,189 @@ def test_multimodal_collator():
|
|||||||
"labels": [
|
"labels": [
|
||||||
[0, 1, 2, 3, q, q, q, q, q, q, q, q],
|
[0, 1, 2, 3, q, q, q, q, q, q, q, q],
|
||||||
],
|
],
|
||||||
"position_ids": [
|
"position_ids": [[[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0]]] * 3,
|
||||||
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
|
"rope_deltas": [[0]],
|
||||||
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
|
|
||||||
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
|
|
||||||
],
|
|
||||||
"rope_deltas": [[-8]],
|
|
||||||
**tokenizer_module["processor"].image_processor(fake_image),
|
**tokenizer_module["processor"].image_processor(fake_image),
|
||||||
}
|
}
|
||||||
|
if not is_transformers_version_greater_than("5.0.0"):
|
||||||
|
# adapt position_ids and rope_deltas for transformers < 5.0.0
|
||||||
|
# https://github.com/huggingface/transformers/pull/43972
|
||||||
|
expected_input["position_ids"] = [[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]]] * 3
|
||||||
|
expected_input["rope_deltas"] = [[-8]]
|
||||||
|
|
||||||
assert batch_input.keys() == expected_input.keys()
|
assert batch_input.keys() == expected_input.keys()
|
||||||
for k in batch_input.keys():
|
for k in batch_input.keys():
|
||||||
|
if k == "position_ids" and batch_input[k].dim() == 3 and batch_input[k].shape[0] == 4:
|
||||||
|
batch_input[k] = batch_input[k][1:]
|
||||||
|
|
||||||
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
|
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_packed_feature(
|
||||||
|
*,
|
||||||
|
packing_params: dict,
|
||||||
|
pad_token_id: int,
|
||||||
|
label_ignore_id: int,
|
||||||
|
fake_image: Image.Image,
|
||||||
|
vision_start_id: int | None = None,
|
||||||
|
vision_end_id: int | None = None,
|
||||||
|
image_pad_id: int | None = None,
|
||||||
|
) -> dict:
|
||||||
|
r"""Build one packed sample using the new PackingParams schema."""
|
||||||
|
sequence_boundaries = packing_params["sequence_boundaries"]
|
||||||
|
image_subseq_ids = packing_params["image_subseq_ids"]
|
||||||
|
video_subseq_ids = packing_params["video_subseq_ids"]
|
||||||
|
audio_subseq_ids = packing_params["audio_subseq_ids"]
|
||||||
|
unpadded_length = packing_params["unpadded_length"]
|
||||||
|
right_padding_length = packing_params["right_padding_length"] # which only preserved in tests
|
||||||
|
cutoff_plus_one = sequence_boundaries[-1]
|
||||||
|
content_len = unpadded_length
|
||||||
|
pad_len = right_padding_length
|
||||||
|
assert content_len + pad_len == cutoff_plus_one
|
||||||
|
assert sequence_boundaries[0] == 0
|
||||||
|
assert sequence_boundaries[-1] == cutoff_plus_one
|
||||||
|
|
||||||
|
content_ids = list(range(100, 100 + content_len))
|
||||||
|
if vision_start_id is not None and vision_end_id is not None and image_pad_id is not None:
|
||||||
|
image_counts_by_subseq = Counter(image_subseq_ids)
|
||||||
|
for subseq_idx, image_count in sorted(image_counts_by_subseq.items()):
|
||||||
|
if subseq_idx >= len(sequence_boundaries) - 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
subseq_start = sequence_boundaries[subseq_idx]
|
||||||
|
subseq_end = sequence_boundaries[subseq_idx + 1]
|
||||||
|
subseq_len = subseq_end - subseq_start
|
||||||
|
if subseq_len < 3:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Build repeated image groups while preserving at least 3 tokens for each remaining image.
|
||||||
|
injected_tokens: list[int] = []
|
||||||
|
remaining = subseq_len
|
||||||
|
for image_idx in range(image_count):
|
||||||
|
remaining_images = image_count - image_idx
|
||||||
|
min_reserved_for_rest = 3 * (remaining_images - 1)
|
||||||
|
current_group_len = min(6, remaining - min_reserved_for_rest)
|
||||||
|
if current_group_len < 3:
|
||||||
|
break
|
||||||
|
|
||||||
|
group = [vision_start_id] + [image_pad_id] * max(1, current_group_len - 2) + [vision_end_id]
|
||||||
|
injected_tokens.extend(group[:current_group_len])
|
||||||
|
remaining -= current_group_len
|
||||||
|
|
||||||
|
if injected_tokens:
|
||||||
|
insert_end = subseq_start + len(injected_tokens)
|
||||||
|
content_ids[subseq_start:insert_end] = injected_tokens
|
||||||
|
|
||||||
|
input_ids = content_ids + [pad_token_id] * pad_len
|
||||||
|
attention_mask = [1] * content_len + [0] * pad_len
|
||||||
|
labels = [label_ignore_id] * cutoff_plus_one
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"labels": labels,
|
||||||
|
"images": [fake_image] * len(image_subseq_ids),
|
||||||
|
"videos": [None] * len(video_subseq_ids),
|
||||||
|
"audios": [None] * len(audio_subseq_ids),
|
||||||
|
"packing_params": packing_params,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_packed_features(
|
||||||
|
*,
|
||||||
|
packing_params: dict,
|
||||||
|
pad_token_id: int,
|
||||||
|
label_ignore_id: int,
|
||||||
|
fake_image: Image.Image,
|
||||||
|
vision_start_id: int,
|
||||||
|
vision_end_id: int,
|
||||||
|
image_pad_id: int,
|
||||||
|
) -> list[dict]:
|
||||||
|
r"""Build packed features from caller-provided packing_params."""
|
||||||
|
return [
|
||||||
|
_make_packed_feature(
|
||||||
|
packing_params=packing_params,
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
label_ignore_id=label_ignore_id,
|
||||||
|
fake_image=fake_image,
|
||||||
|
vision_start_id=vision_start_id,
|
||||||
|
vision_end_id=vision_end_id,
|
||||||
|
image_pad_id=image_pad_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_expected_position_ids(packing_params, get_rope_func, input_ids, attention_mask) -> torch.Tensor:
|
||||||
|
bound_list = packing_params["sequence_boundaries"]
|
||||||
|
input_ids_slices = [input_ids[bound_list[i]:bound_list[i+1]] for i in range(len(bound_list) - 1)]
|
||||||
|
attention_mask_slices = [attention_mask[bound_list[i]:bound_list[i+1]] for i in range(len(bound_list) - 1)]
|
||||||
|
img_counts_by_subseq = Counter(packing_params["image_subseq_ids"])
|
||||||
|
all_position_ids = []
|
||||||
|
for i, input_ids_slice in enumerate(input_ids_slices):
|
||||||
|
img_cnt = img_counts_by_subseq[i]
|
||||||
|
if sum(attention_mask_slices[i]) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
rope_func_kwargs = {
|
||||||
|
"input_ids": torch.tensor(input_ids_slice).unsqueeze(0),
|
||||||
|
"attention_mask": torch.tensor(attention_mask_slices[i]).unsqueeze(0),
|
||||||
|
"image_grid_thw": [torch.tensor([1, 4, 4])] * img_cnt,
|
||||||
|
}
|
||||||
|
position_ids, _ = get_rope_func(**rope_func_kwargs)
|
||||||
|
all_position_ids.append(position_ids)
|
||||||
|
|
||||||
|
return torch.cat(all_position_ids, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
|
def test_multimodal_collator_with_packing():
|
||||||
|
model_args, data_args, *_ = get_infer_args(
|
||||||
|
{"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
|
||||||
|
)
|
||||||
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
|
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
|
||||||
|
tokenizer_module["tokenizer"].padding_side = "right"
|
||||||
|
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||||
|
with torch.device("meta"):
|
||||||
|
model = AutoModelForImageTextToText.from_config(config)
|
||||||
|
|
||||||
|
data_collator = MultiModalDataCollatorForSeq2Seq(
|
||||||
|
template=template,
|
||||||
|
model=model,
|
||||||
|
pad_to_multiple_of=4,
|
||||||
|
label_pad_token_id=IGNORE_INDEX,
|
||||||
|
**tokenizer_module,
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
|
packing_params = {
|
||||||
|
"sequence_boundaries": [0, 2, 10, 18, 28, 32],
|
||||||
|
"image_subseq_ids": [1, 2, 3],
|
||||||
|
"video_subseq_ids": [],
|
||||||
|
"audio_subseq_ids": [],
|
||||||
|
"unpadded_length": 28,
|
||||||
|
"right_padding_length": 4,
|
||||||
|
}
|
||||||
|
fake_image = Image.new("RGB", (64, 64), (255, 255, 255))
|
||||||
|
features = _make_packed_features(
|
||||||
|
packing_params=packing_params,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
label_ignore_id=IGNORE_INDEX,
|
||||||
|
fake_image=fake_image,
|
||||||
|
vision_start_id=tokenizer.convert_tokens_to_ids("<|vision_start|>"),
|
||||||
|
vision_end_id=tokenizer.convert_tokens_to_ids("<|vision_end|>"),
|
||||||
|
image_pad_id=tokenizer.convert_tokens_to_ids("<|image_pad|>"),
|
||||||
|
)
|
||||||
|
expected_position_ids = _get_expected_position_ids(
|
||||||
|
packing_params,
|
||||||
|
data_collator.get_rope_func,
|
||||||
|
features[0]["input_ids"],
|
||||||
|
features[0]["attention_mask"],
|
||||||
|
)
|
||||||
|
batch_input = data_collator(features) # [3, bsz, seq_len]
|
||||||
|
valid_len = expected_position_ids.shape[-1]
|
||||||
|
assert batch_input["position_ids"][1:, :, :valid_len].eq(expected_position_ids).all()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.runs_on(["cpu"])
|
@pytest.mark.runs_on(["cpu"])
|
||||||
def test_4d_attention_mask():
|
def test_4d_attention_mask():
|
||||||
o = 0.0
|
o = 0.0
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
# change if test fails or cache is outdated
|
# change if test fails or cache is outdated
|
||||||
0.9.5.106
|
0.9.5.107
|
||||||
|
|||||||
@@ -172,3 +172,33 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
|||||||
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
||||||
elif CURRENT_DEVICE == "npu":
|
elif CURRENT_DEVICE == "npu":
|
||||||
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
|
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def bypass_mistral_regex_check():
|
||||||
|
"""Disable Mistral regex network check.
|
||||||
|
|
||||||
|
Monkey-patch TokenizersBackend._patch_mistral_regex into a no-op.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from transformers.tokenization_utils_fast import TokenizersBackend
|
||||||
|
except ImportError:
|
||||||
|
# Very old transformers, nothing to patch
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
if not hasattr(TokenizersBackend, "_patch_mistral_regex"):
|
||||||
|
# Method does not exist in this version
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
# Backup original method
|
||||||
|
original = TokenizersBackend._patch_mistral_regex
|
||||||
|
|
||||||
|
# Replace with no-op
|
||||||
|
TokenizersBackend._patch_mistral_regex = lambda cls, tokenizer, *args, **kwargs: tokenizer
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Restore original method
|
||||||
|
TokenizersBackend._patch_mistral_regex = original
|
||||||
|
|||||||
156
tests_v1/plugins/model_plugins/test_peft.py
Normal file
156
tests_v1/plugins/model_plugins/test_peft.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from peft import LoraConfig, PeftModel, get_peft_model
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
from llamafactory.v1.plugins.model_plugins import peft as peft_module
|
||||||
|
from llamafactory.v1.plugins.model_plugins.peft import merge_and_export_model
|
||||||
|
|
||||||
|
|
||||||
|
TINY_MODEL = "llamafactory/tiny-random-qwen3"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def model_path():
|
||||||
|
return TINY_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def model(model_path):
|
||||||
|
return AutoModelForCausalLM.from_pretrained(model_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def tokenizer(model_path):
|
||||||
|
return AutoTokenizer.from_pretrained(model_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def adapter_path(tmp_path):
|
||||||
|
# Create a dummy adapter
|
||||||
|
lora_config = LoraConfig(
|
||||||
|
r=8,
|
||||||
|
lora_alpha=16,
|
||||||
|
target_modules=["q_proj", "v_proj"],
|
||||||
|
lora_dropout=0.05,
|
||||||
|
bias="none",
|
||||||
|
task_type="CAUSAL_LM",
|
||||||
|
)
|
||||||
|
|
||||||
|
base_model = AutoModelForCausalLM.from_pretrained(TINY_MODEL)
|
||||||
|
peft_model = get_peft_model(base_model, lora_config)
|
||||||
|
save_path = tmp_path / "test_adapter"
|
||||||
|
peft_model.save_pretrained(save_path)
|
||||||
|
return str(save_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_all_linear_modules(model):
|
||||||
|
"""Verify linear modules are discoverable and include q_proj / v_proj for tiny-random-qwen3."""
|
||||||
|
modules = peft_module._find_all_linear_modules(model)
|
||||||
|
expected_subset = {"q_proj", "v_proj"}
|
||||||
|
assert expected_subset.issubset(set(modules))
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_lora_model(model):
|
||||||
|
"""Verify a PeftModel is returned and LoRA config takes effect."""
|
||||||
|
config = {"name": "lora", "r": 8, "target_modules": "all", "lora_alpha": 16}
|
||||||
|
model = peft_module.get_lora_model(model, config, is_train=True)
|
||||||
|
assert isinstance(model, PeftModel)
|
||||||
|
assert model.peft_config["default"].r == 8
|
||||||
|
assert "q_proj" in model.peft_config["default"].target_modules
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_freeze_model_layers(model):
|
||||||
|
"""Verify layer-wise freezing: only the last layer stays trainable."""
|
||||||
|
# Freeze all but last layer
|
||||||
|
config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "all"}
|
||||||
|
|
||||||
|
# Ensure we start with something known
|
||||||
|
model = peft_module.get_freeze_model(model, config, is_train=True)
|
||||||
|
|
||||||
|
num_layers = model.config.num_hidden_layers
|
||||||
|
assert num_layers > 0
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if f"layers.{num_layers - 1}" in name:
|
||||||
|
assert param.requires_grad, f"{name} should be trainable"
|
||||||
|
elif "layers.0" in name and num_layers > 1:
|
||||||
|
assert not param.requires_grad, f"{name} should be frozen"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_freeze_model_modules(model):
|
||||||
|
"""Verify module-wise freezing: only last-layer self_attn is trainable."""
|
||||||
|
# Freeze specific modules (e.g. only self_attn)
|
||||||
|
config = {"name": "freeze", "freeze_trainable_layers": 1, "freeze_trainable_modules": "self_attn"}
|
||||||
|
model = peft_module.get_freeze_model(model, config, is_train=True)
|
||||||
|
|
||||||
|
num_layers = model.config.num_hidden_layers
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if f"layers.{num_layers - 1}" in name and "self_attn" in name:
|
||||||
|
assert param.requires_grad, f"{name} should be trainable"
|
||||||
|
else:
|
||||||
|
assert not param.requires_grad, f"{name} should be frozen"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_adapter_single_for_inference(model, adapter_path):
|
||||||
|
"""Verify single adapter is merged+unloaded in inference mode."""
|
||||||
|
# Test loading single adapter for inference (merge and unload)
|
||||||
|
model_result = peft_module.load_adapter(model, adapter_path, is_train=False)
|
||||||
|
assert not isinstance(model_result, PeftModel)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_adapter_resume_train(model, adapter_path):
|
||||||
|
"""Verify training mode returns a trainable PeftModel."""
|
||||||
|
# Test loading for training
|
||||||
|
model_result = peft_module.load_adapter(model, adapter_path, is_train=True)
|
||||||
|
assert isinstance(model_result, PeftModel)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_adapter_train_multiple_disallowed(model, adapter_path):
|
||||||
|
"""Verify multiple adapters are rejected in training mode."""
|
||||||
|
with pytest.raises(ValueError, match="only a single LoRA adapter"):
|
||||||
|
peft_module.load_adapter(model, [adapter_path, adapter_path], is_train=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_adapter_infer_multiple_merges(model, adapter_path):
|
||||||
|
"""Verify multiple adapters are merged in inference mode."""
|
||||||
|
# Test merging multiple adapters
|
||||||
|
model_result = peft_module.load_adapter(model, [adapter_path, adapter_path], is_train=False)
|
||||||
|
assert not isinstance(model_result, PeftModel)
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_and_export_model(tmp_path, adapter_path):
|
||||||
|
"""Verify merge_and_export_model produces export artifacts."""
|
||||||
|
export_dir = tmp_path / "export"
|
||||||
|
|
||||||
|
args_dict = {
|
||||||
|
"model": TINY_MODEL,
|
||||||
|
"peft_config": {
|
||||||
|
"name": "lora",
|
||||||
|
"adapter_name_or_path": adapter_path,
|
||||||
|
"export_dir": str(export_dir),
|
||||||
|
"export_size": 1,
|
||||||
|
"infer_dtype": "float16",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
merge_and_export_model(args_dict)
|
||||||
|
|
||||||
|
assert export_dir.exists()
|
||||||
|
assert (export_dir / "config.json").exists()
|
||||||
|
assert (export_dir / "model.safetensors").exists()
|
||||||
|
assert (export_dir / "tokenizer_config.json").exists()
|
||||||
51
tests_v1/plugins/model_plugins/test_quantization_plugin.py
Normal file
51
tests_v1/plugins/model_plugins/test_quantization_plugin.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llamafactory.v1.config.model_args import ModelArguments
|
||||||
|
from llamafactory.v1.core.model_engine import ModelEngine
|
||||||
|
|
||||||
|
|
||||||
|
bitsandbytes = pytest.importorskip("bitsandbytes")
|
||||||
|
|
||||||
|
|
||||||
|
def check_quantization_status(model):
|
||||||
|
quantized_info = {"bnb": []}
|
||||||
|
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
# check BitsAndBytes quantization
|
||||||
|
if isinstance(module, bitsandbytes.nn.modules.Linear8bitLt) or isinstance(
|
||||||
|
module, bitsandbytes.nn.modules.Linear4bit
|
||||||
|
):
|
||||||
|
quantized_info["bnb"].append(name)
|
||||||
|
|
||||||
|
return quantized_info
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cuda"])
|
||||||
|
@pytest.mark.parametrize("name, quantization_bit", [("bnb", 4), ("auto", 4)])
|
||||||
|
def test_quantization_plugin(name, quantization_bit):
|
||||||
|
model_args = ModelArguments(
|
||||||
|
model="llamafactory/tiny-random-qwen3",
|
||||||
|
quant_config={
|
||||||
|
"name": name,
|
||||||
|
"quantization_bit": quantization_bit,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
model_engine = ModelEngine(model_args=model_args)
|
||||||
|
quantized_info = check_quantization_status(model_engine.model)
|
||||||
|
print(f"Quantized weights for method {name} with {quantization_bit} bit: {quantized_info}")
|
||||||
|
assert any(v for v in quantized_info.values()), "model is not quantized properly."
|
||||||
104
tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py
Normal file
104
tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Unit tests: FSDP2 meta-device loading vs normal loading consistency.
|
||||||
|
|
||||||
|
Validates that the FSDP2 meta loading path behaves correctly for tied weights
|
||||||
|
and non-persistent buffers by comparing it with the standard non-meta path.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
from llamafactory.v1.accelerator.interface import DistributedInterface
|
||||||
|
from llamafactory.v1.config.arg_parser import get_args
|
||||||
|
from llamafactory.v1.core.model_engine import ModelEngine
|
||||||
|
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
|
||||||
|
|
||||||
|
|
||||||
|
TINY_MODEL = "llamafactory/tiny-random-qwen3"
|
||||||
|
|
||||||
|
|
||||||
|
def collect_non_persistent_buffers(model):
|
||||||
|
"""Collect all non-persistent buffers from model."""
|
||||||
|
result = {}
|
||||||
|
for mod_name, module in model.named_modules():
|
||||||
|
for buf_name in getattr(module, "_non_persistent_buffers_set", set()):
|
||||||
|
fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
|
||||||
|
buf = getattr(module, buf_name, None)
|
||||||
|
if buf is not None:
|
||||||
|
result[fqn] = buf.detach().cpu().clone()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def test_fsdp2_meta_loading_buffers_and_tied_weights():
|
||||||
|
"""Verify non-persistent buffers and tied weights consistency after meta load."""
|
||||||
|
# 1. Initialize DistributedInterface for single process
|
||||||
|
DistributedInterface()
|
||||||
|
|
||||||
|
# 2. Build FSDP2Engine config
|
||||||
|
engine = FSDP2Engine(
|
||||||
|
{
|
||||||
|
"name": "fsdp2",
|
||||||
|
"mixed_precision": "bf16",
|
||||||
|
"reshard_after_forward": True,
|
||||||
|
"offload_params": False,
|
||||||
|
"pin_memory": False,
|
||||||
|
"dcp_path": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(TINY_MODEL)
|
||||||
|
|
||||||
|
# --- NORMAL PATH ---
|
||||||
|
normal_args, *_ = get_args(dict(model=TINY_MODEL, init_config=None))
|
||||||
|
normal_engine = ModelEngine(model_args=normal_args)
|
||||||
|
normal_model = normal_engine.model.to(torch.bfloat16)
|
||||||
|
|
||||||
|
normal_model = engine.shard_model(normal_model)
|
||||||
|
normal_non_persistent = collect_non_persistent_buffers(normal_model)
|
||||||
|
|
||||||
|
del normal_model
|
||||||
|
|
||||||
|
# --- META PATH ---
|
||||||
|
meta_args, *_ = get_args(dict(model=TINY_MODEL, init_config={"name": "init_on_meta"}))
|
||||||
|
meta_model_engine = ModelEngine(model_args=meta_args)
|
||||||
|
meta_model = meta_model_engine.model
|
||||||
|
|
||||||
|
assert meta_model.device.type == "meta", "Model should be on meta device"
|
||||||
|
|
||||||
|
# Process meta device: save buffers -> tie_weights -> load from checkpoint -> restore buffers
|
||||||
|
meta_model = engine.shard_model(meta_model)
|
||||||
|
meta_non_persistent = collect_non_persistent_buffers(meta_model)
|
||||||
|
|
||||||
|
# 3. Tied weights (embed_tokens.weight and lm_head.weight)
|
||||||
|
|
||||||
|
tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
|
||||||
|
if tie_word_embeddings:
|
||||||
|
assert meta_model.lm_head.weight is meta_model.model.embed_tokens.weight, (
|
||||||
|
"Weights should be tied after loading"
|
||||||
|
)
|
||||||
|
|
||||||
|
del meta_model
|
||||||
|
|
||||||
|
# 4. Non-persistent buffers (e.g., inv_freq)
|
||||||
|
normal_buf_keys = set(normal_non_persistent.keys())
|
||||||
|
meta_buf_keys = set(meta_non_persistent.keys())
|
||||||
|
assert normal_buf_keys == meta_buf_keys, "Non-persistent buffer keys mismatch"
|
||||||
|
|
||||||
|
for key in sorted(normal_buf_keys & meta_buf_keys):
|
||||||
|
nb = normal_non_persistent[key]
|
||||||
|
mb = meta_non_persistent[key]
|
||||||
|
assert nb.shape == mb.shape, f"Buffer shape mismatch: {key}"
|
||||||
|
assert torch.allclose(nb.float(), mb.float(), atol=1e-5), f"Buffer value mismatch: {key}"
|
||||||
Reference in New Issue
Block a user