mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-20 23:53:09 +00:00
Compare commits
39 Commits
92fa3df4c4
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
833f6027b1 | ||
|
|
d91d8af89e | ||
|
|
e67ab9e2f2 | ||
|
|
2c4f121817 | ||
|
|
487f8b8191 | ||
|
|
78cad1e332 | ||
|
|
70653026f5 | ||
|
|
246192abd2 | ||
|
|
0258dc14d0 | ||
|
|
3045adf0ba | ||
|
|
a3d44e3152 | ||
|
|
edeb953bc7 | ||
|
|
d045794387 | ||
|
|
9501c3308a | ||
|
|
0ee1c42c2b | ||
|
|
3061f48d55 | ||
|
|
2d9bd2aa14 | ||
|
|
c0245c43fc | ||
|
|
eb976d75a2 | ||
|
|
b5cb7cb0e6 | ||
|
|
0779846513 | ||
|
|
45d335c709 | ||
|
|
816480012f | ||
|
|
d3bf882e87 | ||
|
|
589da21d32 | ||
|
|
122cd46084 | ||
|
|
2b8b871475 | ||
|
|
aab9b400bb | ||
|
|
50599c719b | ||
|
|
a0f3ad0cee | ||
|
|
f80e15dbb4 | ||
|
|
991267fd3b | ||
|
|
5c52afa30d | ||
|
|
675ce8cc7f | ||
|
|
ab073f4c13 | ||
|
|
184304b5b4 | ||
|
|
d3ebd5678d | ||
|
|
1d5e8ebcd0 | ||
|
|
ea644d04ec |
77
.github/workflows/docs.yml
vendored
Normal file
77
.github/workflows/docs.yml
vendored
Normal file
@@ -0,0 +1,77 @@
|
||||
name: Build and Deploy Sphinx Docs
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
paths:
|
||||
- "docs/**"
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
paths:
|
||||
- "docs/**"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pages: write
|
||||
id-token: write
|
||||
|
||||
concurrency:
|
||||
group: "pages"
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -r docs/requirements.txt
|
||||
|
||||
- name: Build Sphinx
|
||||
run: |
|
||||
sphinx-build -b html docs/zh docs/_build/html/zh
|
||||
sphinx-build -b html docs/en docs/_build/html/en
|
||||
|
||||
printf '%s\n' \
|
||||
'<!DOCTYPE html>' \
|
||||
'<html>' \
|
||||
' <head>' \
|
||||
' <meta charset="utf-8" />' \
|
||||
' <meta http-equiv="refresh" content="0; url=zh/index.html" />' \
|
||||
' <script>window.location.href="zh/index.html"+window.location.search+window.location.hash;</script>' \
|
||||
' <title>Redirecting...</title>' \
|
||||
' </head>' \
|
||||
' <body>' \
|
||||
' <a href="zh/index.html">Redirecting...</a>' \
|
||||
' </body>' \
|
||||
'</html>' \
|
||||
> docs/_build/html/index.html
|
||||
|
||||
touch docs/_build/html/.nojekyll
|
||||
|
||||
- name: Setup Pages
|
||||
uses: actions/configure-pages@v5
|
||||
|
||||
- name: Upload artifact
|
||||
uses: actions/upload-pages-artifact@v3
|
||||
with:
|
||||
path: docs/_build/html
|
||||
|
||||
deploy:
|
||||
environment:
|
||||
name: github-pages
|
||||
url: ${{ steps.deployment.outputs.page_url }}
|
||||
runs-on: ubuntu-latest
|
||||
needs: build
|
||||
steps:
|
||||
- name: Deploy to GitHub Pages
|
||||
id: deployment
|
||||
uses: actions/deploy-pages@v4
|
||||
9
.github/workflows/tests.yml
vendored
9
.github/workflows/tests.yml
vendored
@@ -35,15 +35,12 @@ jobs:
|
||||
transformers:
|
||||
- ""
|
||||
include: # test backward compatibility
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.51.0"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.53.0"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.55.0"
|
||||
- python: "3.11"
|
||||
os: "ubuntu-latest"
|
||||
transformers: "4.57.1"
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
|
||||
1
.github/workflows/tests_cuda.yml
vendored
1
.github/workflows/tests_cuda.yml
vendored
@@ -61,6 +61,7 @@ jobs:
|
||||
uv venv
|
||||
uv pip install -e .
|
||||
uv pip install -r requirements/dev.txt
|
||||
uv pip install -r requirements/bitsandbytes.txt
|
||||
|
||||
- name: Check quality
|
||||
run: |
|
||||
|
||||
@@ -291,7 +291,7 @@ Read technical notes:
|
||||
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
|
||||
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
|
||||
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
|
||||
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
|
||||
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small|
|
||||
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
|
||||
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
|
||||
| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
|
||||
@@ -319,6 +319,7 @@ Read technical notes:
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
||||
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
|
||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
||||
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
|
||||
@@ -472,7 +473,7 @@ huggingface-cli login
|
||||
|
||||
| Mandatory | Minimum | Recommend |
|
||||
| ------------ | ------- | --------- |
|
||||
| python | 3.9 | 3.10 |
|
||||
| python | 3.11 | >=3.11 |
|
||||
| torch | 2.0.0 | 2.6.0 |
|
||||
| torchvision | 0.15.0 | 0.21.0 |
|
||||
| transformers | 4.49.0 | 4.50.0 |
|
||||
|
||||
@@ -293,7 +293,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
|
||||
| [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
|
||||
| [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
|
||||
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
|
||||
| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small|
|
||||
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
|
||||
| [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
|
||||
| [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
|
||||
@@ -321,6 +321,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
||||
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
|
||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
||||
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
|
||||
@@ -474,7 +475,7 @@ huggingface-cli login
|
||||
|
||||
| 必需项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | --------- |
|
||||
| python | 3.9 | 3.10 |
|
||||
| python | 3.11 | >=3.11 |
|
||||
| torch | 2.0.0 | 2.6.0 |
|
||||
| torchvision | 0.15.0 | 0.21.0 |
|
||||
| transformers | 4.49.0 | 4.50.0 |
|
||||
|
||||
@@ -236,6 +236,13 @@
|
||||
"ms_hub_url": "AI-ModelScope/sharegpt_gpt4",
|
||||
"formatting": "sharegpt"
|
||||
},
|
||||
"sgsc_b2b_entities": {
|
||||
"hf_hub_url": "Nooxus-AI/NOO-Verified-Global-Entities",
|
||||
"formatting": "sharegpt",
|
||||
"columns": {
|
||||
"messages": "messages"
|
||||
}
|
||||
},
|
||||
"ultrachat_200k": {
|
||||
"hf_hub_url": "HuggingFaceH4/ultrachat_200k",
|
||||
"ms_hub_url": "AI-ModelScope/ultrachat_200k",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# https://hub.docker.com/r/ascendai/cann/tags
|
||||
|
||||
ARG BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11
|
||||
ARG BASE_IMAGE=quay.io/ascend/cann:8.5.1-910b-ubuntu22.04-py3.11
|
||||
FROM ${BASE_IMAGE}
|
||||
|
||||
# Installation arguments
|
||||
@@ -33,9 +33,11 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
COPY . /app
|
||||
|
||||
# Install torch-npu
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
|
||||
pip install --no-cache-dir -e . --no-build-isolation && \
|
||||
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
RUN pip uninstall -y torch torchvision torchaudio
|
||||
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
|
||||
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
|
||||
RUN pip install --no-cache-dir -e . --no-build-isolation && \
|
||||
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||
|
||||
# Set up volumes
|
||||
|
||||
@@ -33,7 +33,7 @@ services:
|
||||
dockerfile: ./docker/docker-npu/Dockerfile
|
||||
context: ../..
|
||||
args:
|
||||
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
|
||||
BASE_IMAGE: quay.io/ascend/cann:8.5.1-a3-ubuntu22.04-py3.11
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
container_name: llamafactory-a3
|
||||
image: llamafactory:npu-a3
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# https://hub.docker.com/r/rocm/pytorch/tags
|
||||
ARG BASE_IMAGE=rocm/pytorch:rocm6.4.1_ubuntu22.04_py3.10_pytorch_release_2.6.0
|
||||
# ROCm 7.2 + PyTorch 2.7.1 (Python 3.12). Keep base image's PyTorch; do not reinstall.
|
||||
ARG BASE_IMAGE=rocm/pytorch:rocm7.2_ubuntu24.04_py3.12_pytorch_release_2.7.1
|
||||
FROM ${BASE_IMAGE}
|
||||
|
||||
# Installation arguments
|
||||
ARG PIP_INDEX=https://pypi.org/simple
|
||||
ARG INSTALL_FLASHATTN=false
|
||||
ARG HTTP_PROXY=""
|
||||
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3
|
||||
|
||||
# Define environments
|
||||
ENV MAX_JOBS=16
|
||||
@@ -32,10 +32,9 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
# Copy the application into the image
|
||||
COPY . /app
|
||||
|
||||
# Reinstall pytorch rocm and install LLaMA Factory
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
|
||||
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"
|
||||
# Install LLaMA Factory (use base image's PyTorch/ROCm; do not reinstall)
|
||||
RUN pip install --no-cache-dir -e . --pre && \
|
||||
pip install --no-cache-dir -r requirements/deepspeed.txt -r requirements/liger-kernel.txt -r requirements/bitsandbytes.txt
|
||||
|
||||
# Rebuild flash attention
|
||||
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
||||
|
||||
20
docs/Makefile
Normal file
20
docs/Makefile
Normal file
@@ -0,0 +1,20 @@
|
||||
# Minimal makefile for Sphinx documentation
|
||||
#
|
||||
|
||||
# You can set these variables from the command line, and also
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS =
|
||||
SPHINXBUILD = sphinx-build
|
||||
SOURCEDIR = .
|
||||
BUILDDIR = _build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
49
docs/_static/css/lang-switcher.css
vendored
Normal file
49
docs/_static/css/lang-switcher.css
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
.lang-switcher {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.lang-switcher__select {
|
||||
appearance: none;
|
||||
-webkit-appearance: none;
|
||||
-moz-appearance: none;
|
||||
padding: 6px 28px 6px 10px;
|
||||
border-radius: 999px;
|
||||
border: 1px solid rgba(0, 0, 0, 0.18);
|
||||
background-color: #ffffff;
|
||||
color: #333333;
|
||||
font-size: 13px;
|
||||
line-height: 18px;
|
||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.08);
|
||||
cursor: pointer;
|
||||
background-image: linear-gradient(45deg, transparent 50%, #667085 50%),
|
||||
linear-gradient(135deg, #667085 50%, transparent 50%);
|
||||
background-position: calc(100% - 16px) 50%, calc(100% - 11px) 50%;
|
||||
background-size: 5px 5px, 5px 5px;
|
||||
background-repeat: no-repeat;
|
||||
}
|
||||
|
||||
.lang-switcher__select:focus {
|
||||
outline: none;
|
||||
border-color: rgba(41, 128, 185, 0.65);
|
||||
box-shadow: 0 0 0 3px rgba(41, 128, 185, 0.18);
|
||||
}
|
||||
|
||||
.wy-side-nav-search .lang-switcher {
|
||||
margin-top: 10px;
|
||||
}
|
||||
|
||||
.wy-side-nav-search .lang-switcher__select {
|
||||
border-color: rgba(255, 255, 255, 0.18);
|
||||
background-color: rgba(255, 255, 255, 0.08);
|
||||
color: #ffffff;
|
||||
box-shadow: none;
|
||||
background-image: linear-gradient(45deg, transparent 50%, rgba(255, 255, 255, 0.75) 50%),
|
||||
linear-gradient(135deg, rgba(255, 255, 255, 0.75) 50%, transparent 50%);
|
||||
}
|
||||
|
||||
.wy-side-nav-search .lang-switcher__select:focus {
|
||||
border-color: rgba(255, 255, 255, 0.45);
|
||||
box-shadow: 0 0 0 3px rgba(255, 255, 255, 0.12);
|
||||
}
|
||||
93
docs/_static/js/switcher.js
vendored
Normal file
93
docs/_static/js/switcher.js
vendored
Normal file
@@ -0,0 +1,93 @@
|
||||
document.addEventListener('DOMContentLoaded', function () {
|
||||
var path = window.location.pathname || '';
|
||||
var isZh = path.indexOf('/zh/') !== -1;
|
||||
var isEn = path.indexOf('/en/') !== -1;
|
||||
if (!isZh && !isEn) return;
|
||||
|
||||
var currentLang = isZh ? 'zh' : 'en';
|
||||
|
||||
function buildSwitcher() {
|
||||
var container = document.createElement('div');
|
||||
container.className = 'lang-switcher';
|
||||
|
||||
var select = document.createElement('select');
|
||||
select.setAttribute('aria-label', 'Language');
|
||||
select.className = 'lang-switcher__select';
|
||||
|
||||
var optionZh = document.createElement('option');
|
||||
optionZh.value = 'zh';
|
||||
optionZh.textContent = 'Simplified Chinese';
|
||||
optionZh.selected = isZh;
|
||||
|
||||
var optionEn = document.createElement('option');
|
||||
optionEn.value = 'en';
|
||||
optionEn.textContent = 'English';
|
||||
optionEn.selected = isEn;
|
||||
|
||||
select.appendChild(optionZh);
|
||||
select.appendChild(optionEn);
|
||||
|
||||
select.addEventListener('change', function () {
|
||||
var nextLang = select.value;
|
||||
if (nextLang === currentLang) return;
|
||||
var targetUrl = path.replace('/' + currentLang + '/', '/' + nextLang + '/');
|
||||
window.location.href = targetUrl + window.location.search + window.location.hash;
|
||||
});
|
||||
|
||||
container.appendChild(select);
|
||||
return container;
|
||||
}
|
||||
|
||||
function hideOtherLanguageToc() {
|
||||
var captions = document.querySelectorAll('p.caption');
|
||||
for (var i = 0; i < captions.length; i++) {
|
||||
var caption = captions[i];
|
||||
var textEl = caption.querySelector('.caption-text');
|
||||
if (!textEl) continue;
|
||||
var label = (textEl.textContent || '').trim().toLowerCase();
|
||||
|
||||
var isCaptionZh = label === '中文' || label === 'chinese' || label === 'zh';
|
||||
var isCaptionEn = label === 'english' || label === 'en';
|
||||
|
||||
if (!isCaptionZh && !isCaptionEn) continue;
|
||||
|
||||
var shouldHide = (currentLang === 'zh' && isCaptionEn) || (currentLang === 'en' && isCaptionZh);
|
||||
var shouldHideCaption = true;
|
||||
|
||||
var next = caption.nextElementSibling;
|
||||
if (next && next.tagName && next.tagName.toLowerCase() === 'ul') {
|
||||
if (shouldHide) {
|
||||
caption.style.display = 'none';
|
||||
next.style.display = 'none';
|
||||
} else if (shouldHideCaption) {
|
||||
caption.style.display = 'none';
|
||||
}
|
||||
} else if (shouldHide) {
|
||||
caption.style.display = 'none';
|
||||
} else if (shouldHideCaption) {
|
||||
caption.style.display = 'none';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var side = document.querySelector('.wy-side-nav-search');
|
||||
if (side) {
|
||||
var sideSwitcher = buildSwitcher();
|
||||
sideSwitcher.style.marginTop = '8px';
|
||||
sideSwitcher.style.display = 'flex';
|
||||
sideSwitcher.style.justifyContent = 'center';
|
||||
side.appendChild(sideSwitcher);
|
||||
} else {
|
||||
var topRight = buildSwitcher();
|
||||
topRight.style.position = 'fixed';
|
||||
topRight.style.top = '12px';
|
||||
topRight.style.right = '12px';
|
||||
topRight.style.zIndex = '9999';
|
||||
document.body.appendChild(topRight);
|
||||
}
|
||||
|
||||
hideOtherLanguageToc();
|
||||
window.addEventListener('load', hideOtherLanguageToc);
|
||||
setTimeout(hideOtherLanguageToc, 50);
|
||||
setTimeout(hideOtherLanguageToc, 300);
|
||||
});
|
||||
35
docs/conf.py
Normal file
35
docs/conf.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
|
||||
|
||||
# Define common settings here
|
||||
project = "LlamaFactory"
|
||||
copyright = "2024, LlamaFactory Team"
|
||||
author = "LlamaFactory Team"
|
||||
|
||||
extensions = [
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.viewcode",
|
||||
"sphinx.ext.napoleon",
|
||||
"myst_parser",
|
||||
]
|
||||
|
||||
templates_path = ["_templates"]
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
html_theme = "sphinx_rtd_theme"
|
||||
|
||||
html_static_path = ["_static"]
|
||||
|
||||
html_js_files = [
|
||||
"js/switcher.js",
|
||||
]
|
||||
|
||||
html_css_files = [
|
||||
"css/lang-switcher.css",
|
||||
]
|
||||
|
||||
myst_enable_extensions = [
|
||||
"colon_fence",
|
||||
"deflist",
|
||||
]
|
||||
myst_heading_anchors = 3
|
||||
3
docs/en/advanced/custom-kernels/custom-kernels.md
Normal file
3
docs/en/advanced/custom-kernels/custom-kernels.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Custom Kernels
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/advanced/custom-kernels/fused-operators.md
Normal file
3
docs/en/advanced/custom-kernels/fused-operators.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Fused Operators
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/advanced/custom-kernels/triton.md
Normal file
3
docs/en/advanced/custom-kernels/triton.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Triton
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/advanced/distributed/deepspeed.md
Normal file
3
docs/en/advanced/distributed/deepspeed.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# DeepSpeed
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/advanced/distributed/fsdp.md
Normal file
3
docs/en/advanced/distributed/fsdp.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# FSDP
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/advanced/distributed/parallel-dp-tp-ep-sp-cp.md
Normal file
3
docs/en/advanced/distributed/parallel-dp-tp-ep-sp-cp.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Parallel (DP, TP, EP, SP, CP)
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/advanced/lora-and-quantization/lora.md
Normal file
3
docs/en/advanced/lora-and-quantization/lora.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# LoRA
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/advanced/lora-and-quantization/quantization.md
Normal file
3
docs/en/advanced/lora-and-quantization/quantization.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Quantization
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
22
docs/en/conf.py
Normal file
22
docs/en/conf.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
# Add parent dir to path to allow importing conf.py
|
||||
sys.path.insert(0, os.path.abspath(".."))
|
||||
|
||||
from conf import * # noqa: F403
|
||||
|
||||
|
||||
# Language settings
|
||||
language = "en"
|
||||
html_search_language = "en"
|
||||
|
||||
# Static files
|
||||
# Point to the root _static directory
|
||||
html_static_path = ["../_static"]
|
||||
|
||||
# Add custom JS for language switcher
|
||||
html_js_files = [
|
||||
"js/switcher.js",
|
||||
]
|
||||
3
docs/en/data-preparation/data-processing.md
Normal file
3
docs/en/data-preparation/data-processing.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Data Processing
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/dev-guide/core/data-engine.md
Normal file
3
docs/en/dev-guide/core/data-engine.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# DataEngine
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/dev-guide/core/model-engine.md
Normal file
3
docs/en/dev-guide/core/model-engine.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# ModelEngine
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/dev-guide/core/trainer.md
Normal file
3
docs/en/dev-guide/core/trainer.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Trainer
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/dev-guide/plugins/data-plugins.md
Normal file
3
docs/en/dev-guide/plugins/data-plugins.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Data Plugins
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
@@ -0,0 +1,3 @@
|
||||
# Initialization
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/dev-guide/plugins/model-plugins/kernels.md
Normal file
3
docs/en/dev-guide/plugins/model-plugins/kernels.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Kernels
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/dev-guide/plugins/model-plugins/rendering.md
Normal file
3
docs/en/dev-guide/plugins/model-plugins/rendering.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Rendering
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/getting-started.md
Normal file
3
docs/en/getting-started.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Getting Started
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/hyperparameters/data-argument.md
Normal file
3
docs/en/hyperparameters/data-argument.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Data Argument
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/hyperparameters/model-argument.md
Normal file
3
docs/en/hyperparameters/model-argument.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Model Argument
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/hyperparameters/sample-argument.md
Normal file
3
docs/en/hyperparameters/sample-argument.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Sample Argument
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/hyperparameters/training-argument.md
Normal file
3
docs/en/hyperparameters/training-argument.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Training Argument
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
62
docs/en/index.rst
Normal file
62
docs/en/index.rst
Normal file
@@ -0,0 +1,62 @@
|
||||
LlamaFactory Docs
|
||||
=================
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Getting Started
|
||||
|
||||
getting-started
|
||||
installation
|
||||
llamaboard-web-ui
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Data Preparation
|
||||
|
||||
data-preparation/data-processing
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Training
|
||||
|
||||
training/sft
|
||||
training/dpo
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Inference
|
||||
|
||||
inference/deploy
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Advanced
|
||||
|
||||
advanced/lora-and-quantization/lora
|
||||
advanced/lora-and-quantization/quantization
|
||||
advanced/distributed/fsdp
|
||||
advanced/distributed/deepspeed
|
||||
advanced/distributed/parallel-dp-tp-ep-sp-cp
|
||||
advanced/custom-kernels/triton
|
||||
advanced/custom-kernels/fused-operators
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Hyperparameters
|
||||
|
||||
hyperparameters/data-argument
|
||||
hyperparameters/model-argument
|
||||
hyperparameters/sample-argument
|
||||
hyperparameters/training-argument
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Dev Guide
|
||||
|
||||
dev-guide/core/data-engine
|
||||
dev-guide/core/model-engine
|
||||
dev-guide/core/trainer
|
||||
dev-guide/plugins/data-plugins
|
||||
dev-guide/plugins/model-plugins/initialization
|
||||
dev-guide/plugins/model-plugins/kernels
|
||||
dev-guide/plugins/model-plugins/rendering
|
||||
3
docs/en/inference/deploy.md
Normal file
3
docs/en/inference/deploy.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Deploy
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/installation.md
Normal file
3
docs/en/installation.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Installation
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/llamaboard-web-ui.md
Normal file
3
docs/en/llamaboard-web-ui.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# LlamaBoard Web UI
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/training/dpo.md
Normal file
3
docs/en/training/dpo.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# DPO
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
3
docs/en/training/sft.md
Normal file
3
docs/en/training/sft.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# SFT
|
||||
|
||||
This page is not yet available in English. Use the language switcher to view Simplified Chinese.
|
||||
35
docs/make.bat
Normal file
35
docs/make.bat
Normal file
@@ -0,0 +1,35 @@
|
||||
@ECHO OFF
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=.
|
||||
set BUILDDIR=_build
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to your PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.http://sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
|
||||
:end
|
||||
popd
|
||||
3
docs/requirements.txt
Normal file
3
docs/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
sphinx>=6.0.0
|
||||
sphinx-rtd-theme>=1.2.0
|
||||
myst-parser>=2.0.0
|
||||
93
docs/zh/advanced/custom-kernels/custom-kernels.md
Normal file
93
docs/zh/advanced/custom-kernels/custom-kernels.md
Normal file
@@ -0,0 +1,93 @@
|
||||
# LLaMA-Factory Kernels 系统
|
||||
|
||||
## 概述
|
||||
|
||||
LLaMA-Factory Kernels 系统用于管理不同硬件设备提供的高性能计算内核(kernel)实现,该系统通过替换模型中的关键模块(如 RMSNorm、SwiGLU、RoPE、MoE 等)为硬件优化的版本,从而显著提升模型训练和推理的性能。
|
||||
|
||||
Kernels 系统采用基于注册表的自动发现机制,能够根据当前运行环境自动检测可用的硬件设备(NPU、CUDA 等),并使能相应的高性能 kernels。这种设计使得用户无需关心底层实现细节,只需简单调用接口即可获得性能提升。
|
||||
|
||||
## 核心特性
|
||||
|
||||
- **自动注册机制**:基于 `@register_kernel` 装饰器实现自动注册系统。系统启动时会自动扫描 `ops` 目录下的 kernel 实现,并将其注册到全局注册表中。
|
||||
|
||||
- **设备适配感知**:自动检测当前硬件设备(NPU、CUDA 等)并应用相应的优化。系统会跳过不支持的设备,确保在不同环境下都能正常工作。
|
||||
|
||||
- **模块化设计**:每个 kernel 独立实现,互不干扰。可以单独应用某个 kernel,也可以批量应用所有默认的 kernels。
|
||||
|
||||
- **后向兼容**:kernel 替换不修改模型权重,保持数值一致性。优化后的实现与原始实现保持精度一致(在浮点误差范围内)。
|
||||
|
||||
- **灵活扩展**:通过继承 `BaseKernel` 基类并使用装饰器,可以轻松添加新的 kernel 实现,支持新的硬件设备或优化算法。
|
||||
|
||||
## 使用方式
|
||||
|
||||
### 1. 通过训练 YAML 配置文件使用
|
||||
|
||||
要在训练过程中使能 kernels,只需在配置文件中增加如下配置,即可自动使能所有默认可用的 kernels:
|
||||
|
||||
```yaml
|
||||
...
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
### 2. 调用 API 使能
|
||||
|
||||
#### 2.1 apply_default_kernels 使能所有默认 kernels
|
||||
|
||||
`apply_default_kernels` API 能够自动应用当前设备上所有默认注册的 kernels:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
from llamafactory.v1.plugins.model_plugins.kernels import apply_default_kernels
|
||||
|
||||
# 加载模型
|
||||
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
|
||||
|
||||
# 自动应用所有默认 kernels
|
||||
model = apply_default_kernels(model, include_kernels="auto")
|
||||
```
|
||||
|
||||
#### 2.2 apply_kernel 使能特定 kernel
|
||||
|
||||
如果需要更精细的控制,例如在某些场合单独应用某个 kernel,可以手动调用 `apply_kernel` 函数并传入 kernel ID:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
from llamafactory.v1.plugins.model_plugins.kernels import apply_kernel
|
||||
|
||||
# 加载模型
|
||||
model = AutoModelForCausalLM.from_pretrained("qwen/qwen2.5-0.5B")
|
||||
|
||||
# 手动应用各个 kernels
|
||||
# 注意:kernel ID 必须与定义时的 _kernel_id 一致
|
||||
model = apply_kernel("npu_fused_rope", model=model)
|
||||
model = apply_kernel("npu_fused_rmsnorm", model=model)
|
||||
model = apply_kernel("npu_fused_swiglu", model=model)
|
||||
model = apply_kernel("npu_fused_moe", model=model)
|
||||
|
||||
### 3. 查询已注册的可用 kernels
|
||||
|
||||
可以通过 `get_default_kernels` 获取当前环境中所有已注册且可用的默认 kernel ID:
|
||||
|
||||
```python
|
||||
from llamafactory.v1.plugins.model_plugins.kernels import get_default_kernels
|
||||
|
||||
# 获取默认 kernel 列表
|
||||
available_kernels = get_default_kernels()
|
||||
print(f"Available kernels: {available_kernels}")
|
||||
# 输出示例: ['npu_fused_rmsnorm', 'npu_fused_swiglu', 'npu_fused_rope', 'npu_fused_moe']
|
||||
```
|
||||
|
||||
### 当前已实现的 kernels
|
||||
|
||||
| Kernel ID | 功能 | 支持的设备 | 备注 |
|
||||
|-----------|------|-----------|------|
|
||||
| [npu_fused_rmsnorm](./fused-operators.md/#npufusedrmsnorm) | RMSNorm 融合算子 | NPU | NPU 设备的高性能 RMSNorm 实现 |
|
||||
| [npu_fused_swiglu](./fused-operators.md/#npufusedswiglu) | SwiGLU 融合算子 | NPU | NPU 设备的高性能 SwiGLU 实现 |
|
||||
| [npu_fused_rope](./fused-operators.md/#npufusedrope) | RoPE 融合算子 | NPU | NPU 设备的高性能 RoPE 实现 |
|
||||
| [npu_fused_moe](./fused-operators.md/#npufusedmoe) | MoE 融合算子 | NPU | MoE 融合算子,适配 Qwen3-MoE 等模型 |
|
||||
|
||||
我们会持续适配更多的 kernels,如果您需要自己开发新的 kernels,请参考我们的 [Kernel 开发文档](../../dev-guide/plugins/model-plugins/kernels.md),欢迎您向 LLaMA-Factory 贡献代码。
|
||||
104
docs/zh/advanced/custom-kernels/fused-operators.md
Normal file
104
docs/zh/advanced/custom-kernels/fused-operators.md
Normal file
@@ -0,0 +1,104 @@
|
||||
# Fused Operators
|
||||
|
||||
LLaMA-Factory 提供了一系列针对特定硬件优化的融合算子。这些算子位于 `src/llamafactory/v1/plugins/model_plugins/kernels/ops` 目录下。
|
||||
|
||||
系统启动时,`scan_all_kernels` 函数会自动扫描该目录,注册所有可用的算子。您可以通过 `apply_default_kernels(model, include_kernels="auto")` 一键启用它们,或者使用 `apply_kernel` 单独启用。
|
||||
|
||||
以下是当前支持的融合算子详情:
|
||||
|
||||
## NpuFusedRMSNorm
|
||||
RMSNorm(Root Mean Square Layer Normalization)是一种常用于大模型的归一化方法。在推理或训练中,RMSNorm 融合算子 将bias、residual等操作进行融合,可以减少显存访问次数,加速计算。
|
||||
|
||||
Ascend npu 通过 `torch_npu.npu_rms_norm` 接口提供 RMSNorm 融合算子调用接口,支持 float16, bfloat16, float 等数据格式。RMSNorm 算子常见于Qwen等LLM模型中,由于torch侧没有提供 RMSNorm 算子的接口,因此在模型中通常是以自定义类的形式出现,通过替换 RMSNorm 类的 `forward` 方法即可使能。
|
||||
|
||||
```python
|
||||
def _npu_rms_forward(self, hidden_states):
|
||||
"""NPU forward implementation for RMSNorm.
|
||||
|
||||
Args:
|
||||
self: RMSNorm module instance with `weight` and `variance_epsilon`.
|
||||
hidden_states: Input hidden states tensor, same shape as the baseline.
|
||||
|
||||
Returns:
|
||||
Normalized tensor consistent with the baseline RMSNorm behavior.
|
||||
"""
|
||||
|
||||
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
|
||||
```
|
||||
|
||||
在 LlamaFactory 中,通过 `NpuRMSNormKernel` 提供使能该融合算子的入口,只需要调用 `apply_kernel("npu_fused_rmsnorm", model=model)` 即可针对已适配的模型使能 npu RMSNorm 融合算子。
|
||||
|
||||
## NpuFusedSwiGlu
|
||||
SwiGLU(Swish-Gated Linear Unit)是一种结合了Swish激活函数和门控线性单元(GLU)的混合激活函数,其主要功能是对输入张量进行门控线性变换,近年来被广泛应用于 LLM 模型中的 MLP 层。SwiGLU 融合算子将分割、激活、矩阵乘等多个操作融合为单一硬件指令,避免多次内核启动开销。
|
||||
|
||||
Ascend npu 通过 `torch_npu.npu_swiglu` 接口提供 SwiGLU 融合算子调用接口,支持 float16,bfloat16,float SwiGLU 算子常见于Qwen等LLM模型中,由于torch侧没有提供 SwiGLU 算子的接口,因此在模型中通常是以自定义类的形式出现,通过替换 SwiGLU 类的 `forward` 方法即可使能。替换过程可参考如下示例:
|
||||
|
||||
```python
|
||||
# 原始 MLP forward 方法:
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
# 替换后的 forward 方法:
|
||||
def _npu_swiglu_forward(self, hidden_state):
|
||||
return self.down_proj(
|
||||
torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
|
||||
)
|
||||
```
|
||||
|
||||
在 LLaMA-Factory 中,通过 `NpuSwiGluKernel` 提供使能该融合算子的入口,只需要调用 `apply_kernel("npu_fused_swiglu", model=model)` 即可针对已适配的模型使能 npu SwiGLU 融合算子。对于未适配的模型,如有需要,您可根据示例以及[开发者文档](../../dev-guide/plugins/model-plugins/kernels.md)自行适配。
|
||||
|
||||
|
||||
## NpuFusedRoPE
|
||||
RoPE(Rotary Positional Embedding,旋转式位置嵌入) 是一种位置编码技术,广泛应用于 Qwen 等 LLM 模型中,用于有效编码文本序列的位置信息。它结合了绝对位置编码的稳定性与相对位置编码的灵活性,同时具备优秀的长度泛化能力。传统 RoPE 算子通常在 LLM 等模型结构中通过自定义函数的形式实现。RoPE 融合算子将原计算流程合并为单个硬件优化算子,从而提升性能。
|
||||
|
||||
Ascend npu 通过 `torch_npu.npu_rotary_mul` 提供 RoPE 融合算子调用接口,支持 float16,bfloat16,float32 等数据格式。以 Qwen3 系列模型为例,通过替换其 `apply_rotary_pos_emb` 函数即可实现 RoPE融合算子使能:
|
||||
|
||||
```python
|
||||
# 原始 apply_rotary_pos_emb:
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
# 替换 RoPE 融合算子后:
|
||||
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
|
||||
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
|
||||
return q_embed, k_embed
|
||||
```
|
||||
|
||||
在 LLaMA-Factory 中,通过 `NpuRoPEKernel` 提供使能该融合算子的入口,只需要调用 `apply_kernel("npu_fused_rope", model=model)` 即可针对已适配的模型使能 npu RoPE 融合算子。对于未适配的模型,如有需要,您可根据示例以及[开发者文档](../../dev-guide/plugins/model-plugins/kernels.md)自行适配。
|
||||
|
||||
|
||||
## NpuFusedMoE
|
||||
MoE(Mixture of Experts)模型通过稀疏激活扩展容量。在原生 Transformers 实现中,使用串行循环遍历专家,导致内核启动开销大、硬件利用率低。
|
||||
|
||||
**MoE 融合算子** 利用 **GMM(Grouped Matrix Multiplication,分组矩阵乘)** 技术,支持在单个硬件指令内并行处理多组不同形状(行数不一)的矩阵乘法,消减循环开销,同时无需额外的显存复制,显著提升训练性能。
|
||||
|
||||
Ascend npu 通过 `torch_npu.npu_grouped_matmul` 等接口提供底层支持,通过替换模型中的 MoE Block forward 方法,即可利用 NPU 的分组矩阵乘能力。
|
||||
|
||||
核心逻辑替换如下(简化示意):
|
||||
|
||||
```python
|
||||
def _npu_moe_forward(self, hidden_states, routing_weights, router_indices):
|
||||
# 1. 排序:将乱序的 Token 按指派的专家归类,并生成索引映射
|
||||
permuted_states, row_map = torch_npu.npu_moe_token_permute(hidden_states, router_indices)
|
||||
|
||||
# 2. 统计:计算每个专家需要处理的 Token 数量
|
||||
tokens_per_expert = torch.histc(router_indices, bins=self.num_experts, min=0, max=self.num_experts)
|
||||
|
||||
# 3. 计算 (GMM):一次性并行计算所有专家的权重,自动适配不同专家的输入长度
|
||||
inter_states = torch_npu.npu_grouped_matmul(permuted_states, self.gate_up_proj_weights, split_sizes=tokens_per_expert, ...)
|
||||
inter_states = torch_npu.npu_swiglu(inter_states)
|
||||
output = torch_npu.npu_grouped_matmul(inter_states, self.down_proj_weights, split_sizes=tokens_per_expert, ...)
|
||||
|
||||
# 4. 还原:将结果恢复成原始 Token 顺序并应用路由权重
|
||||
return torch_npu.npu_moe_token_unpermute(output, row_map, routing_weights)
|
||||
```
|
||||
|
||||
在 LLaMA-Factory 中,通过 `NpuFusedMoEKernel` 提供使能该融合算子的入口。只需要调用 `apply_kernel("npu_fused_moe", model=model)` 即可针对已适配的模型使能 NPU MoE 融合算子。对于未适配的模型,您也可以参考上述示例代码以及[开发者文档](../../dev-guide/plugins/model-plugins/kernels.md)自行适配。
|
||||
1
docs/zh/advanced/custom-kernels/triton.md
Normal file
1
docs/zh/advanced/custom-kernels/triton.md
Normal file
@@ -0,0 +1 @@
|
||||
# Triton
|
||||
1
docs/zh/advanced/distributed/deepspeed.md
Normal file
1
docs/zh/advanced/distributed/deepspeed.md
Normal file
@@ -0,0 +1 @@
|
||||
# DeepSpeed
|
||||
1
docs/zh/advanced/distributed/fsdp.md
Normal file
1
docs/zh/advanced/distributed/fsdp.md
Normal file
@@ -0,0 +1 @@
|
||||
# FSDP
|
||||
1
docs/zh/advanced/distributed/parallel-dp-tp-ep-sp-cp.md
Normal file
1
docs/zh/advanced/distributed/parallel-dp-tp-ep-sp-cp.md
Normal file
@@ -0,0 +1 @@
|
||||
# Parallel(DP, TP, EP, SP, CP)
|
||||
3
docs/zh/advanced/lora-and-quantization/lora.md
Normal file
3
docs/zh/advanced/lora-and-quantization/lora.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Lora
|
||||
|
||||
参数管理(二级参数形式)
|
||||
1
docs/zh/advanced/lora-and-quantization/quantization.md
Normal file
1
docs/zh/advanced/lora-and-quantization/quantization.md
Normal file
@@ -0,0 +1 @@
|
||||
# Quantization
|
||||
22
docs/zh/conf.py
Normal file
22
docs/zh/conf.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
# Add parent dir to path to allow importing conf.py
|
||||
sys.path.insert(0, os.path.abspath(".."))
|
||||
|
||||
from conf import * # noqa: F403
|
||||
|
||||
|
||||
# Language settings
|
||||
language = "zh_CN"
|
||||
html_search_language = "zh"
|
||||
|
||||
# Static files
|
||||
# Point to the root _static directory
|
||||
html_static_path = ["../_static"]
|
||||
|
||||
# Add custom JS for language switcher
|
||||
html_js_files = [
|
||||
"js/switcher.js",
|
||||
]
|
||||
479
docs/zh/data-preparation/data-processing.md
Normal file
479
docs/zh/data-preparation/data-processing.md
Normal file
@@ -0,0 +1,479 @@
|
||||
# LLaMA-Factory v1 数据预处理
|
||||
|
||||
## 总览
|
||||
|
||||
LLaMA-Factory `v1` 采用了全新的数据处理架构,主要包含以下核心组件:
|
||||
|
||||
- **DataEngine**:数据引擎,负责数据集的加载、索引和转换等各种插件的接入和调用,并提供数据访问接口
|
||||
- **DataConverterPlugin**:数据转换器,将非标准格式转换为统一的标准格式
|
||||
- **DataLoaderPlugin**:数据加载插件,支持多种文件格式的加载
|
||||
- **DataIndexPlugin**:数据索引插件,支持数据集的采样和权重调整
|
||||
- **DataSelectorPlugin**:数据选择插件,支持灵活的数据访问方式
|
||||
|
||||
与 LLaMA-Factory `v0` 版本相比,`v1` 版本采用了统一的数据格式(Messages Format),所有数据都会被转换为标准的对话消息列表;此外,`v1` 版本通过 DataEngine 与 Plugin 机制,提供了自定义数据处理流的接口,具有更好的可扩展性和一致性。
|
||||
|
||||
---
|
||||
|
||||
## 目录
|
||||
|
||||
- [基本用法](#基本用法)
|
||||
- [标准数据格式](#标准数据格式)
|
||||
- [数据集配置文件](#数据集配置文件)
|
||||
- [完整示例](#完整示例)
|
||||
|
||||
---
|
||||
|
||||
## 基本用法
|
||||
|
||||
### 在训练配置文件,可以通过如下方式配置数据集:
|
||||
|
||||
<details open>
|
||||
<summary>方式 1:使用 HF Hub Repo ID</summary>
|
||||
|
||||
直接指定 HF Hub 上的数据集 Repo ID,DataEngine 会自动从 HF Hub 下载并加载数据集。
|
||||
|
||||
注:使用 Repo ID 直接加载的数据集需要为标准格式
|
||||
|
||||
**训练配置文件示例:**
|
||||
|
||||
```yaml
|
||||
# example_sft.yaml
|
||||
|
||||
...
|
||||
|
||||
dataset: llamafactory/v1-sft-demo # HF Hub Repo ID
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>方式 2:使用 HF Hub 上的 YAML 配置文件</summary>
|
||||
|
||||
`dataset`字段指定 HF Hub 上的 `dataset_info.yaml` 的 URI,DataEngine 会自动下载该配置文件并根据其中的配置加载数据集。
|
||||
|
||||
**训练配置文件示例:**
|
||||
|
||||
```yaml
|
||||
# example_sft.yaml
|
||||
|
||||
...
|
||||
|
||||
dataset: llamafactory/v1-sft-demo/dataset_info.yaml # 远程 dataset_info.yaml 路径
|
||||
|
||||
...
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>方式 3:使用本地 HF 数据集文件路径</summary>
|
||||
|
||||
`dataset`字段指定本地的数据集文件路径(`.json`、`.jsonl` 等)
|
||||
|
||||
注:直接指定数据集文件路径,要求该数据文件的格式已为标准格式
|
||||
|
||||
**训练配置文件示例:**
|
||||
|
||||
```yaml
|
||||
# example_sft.yaml
|
||||
|
||||
...
|
||||
|
||||
dataset: ~/data/v1_sft_demo.jsonl # 本地数据集文件绝对路径
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>方式 4:使用本地 YAML 配置文件路径</summary>
|
||||
|
||||
`dataset`字段指定本地的 `dataset_info.yaml` 配置文件路径,DataEngine 会根据该配置加载其中的数据集。
|
||||
|
||||
**训练配置文件示例:**
|
||||
|
||||
```yaml
|
||||
# example_sft.yaml
|
||||
|
||||
...
|
||||
|
||||
dataset: ~/data/dataset_info.yaml # 本地 dataset_info.yaml 文件路径
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
---
|
||||
|
||||
|
||||
|
||||
## 标准数据格式
|
||||
|
||||
v1 使用统一的 **Messages 格式**作为标准数据格式。每个样本都是一个包含 `messages` 字段的 JSON 对象。
|
||||
|
||||
针对alpaca、sharegpt、以及dpo等格式的数据,可以通过内置的`DataConverterPlugin`插件,自动将其转化为标准格式,对于其他自定义格式的数据,用户也可通过自定义`DataConverterPlugin`来实现数据格式标准化,这部分内容参见[`DataConverterPlugin`](../dev-guide/plugins/data-plugins.md/#data-converter-plugin)
|
||||
|
||||
### 1. SFT(监督微调)样本格式
|
||||
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "value": "You are a helpful assistant."}],
|
||||
"loss_weight": 0.0
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": "Hello, who are you?"}],
|
||||
"loss_weight": 0.0
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": "I am an AI assistant."}],
|
||||
"loss_weight": 1.0
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### 字段说明:
|
||||
|
||||
- **messages**: 消息列表,包含一轮或多轮对话
|
||||
- **role**: 消息角色,可选值:
|
||||
- `"system"`: 系统提示
|
||||
- `"user"`: 用户输入
|
||||
- `"assistant"`: 模型回复
|
||||
- **content**: 内容列表,每个元素包含:
|
||||
- **type**: 内容类型,可选值:
|
||||
- `"text"`: 文本内容
|
||||
- `"image_url"`: 图像 URL(多模态)
|
||||
- `"audio_url"`: 音频 URL(多模态)
|
||||
- `"video_url"`: 视频 URL(多模态)
|
||||
- `"tools"`: 工具描述
|
||||
- `"tool_calls"`: 工具调用
|
||||
- `"reasoning"`: 推理过程
|
||||
- **value**: 具体内容(字符串)
|
||||
- **loss_weight**: 损失权重(浮点数)
|
||||
- `0.0`: 不计算损失(用于提示词部分)
|
||||
- `1.0`: 完全计算损失(用于回复部分)
|
||||
- 可设置为其他值以调整不同部分的学习权重
|
||||
|
||||
- **_dataset_name** (可选): 数据集名称,由 DataEngine 自动添加
|
||||
- **extra_info** (可选): 额外信息字段
|
||||
|
||||
### 2. DPO(偏好对齐)样本格式
|
||||
|
||||
```json
|
||||
{
|
||||
"chosen_messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": "用户提问"}],
|
||||
"loss_weight": 0.0
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": "更优的回答"}],
|
||||
"loss_weight": 1.0
|
||||
}
|
||||
],
|
||||
"rejected_messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": "用户提问"}],
|
||||
"loss_weight": 0.0
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": "较差的回答"}],
|
||||
"loss_weight": 1.0
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 多模态支持
|
||||
|
||||
对于多模态数据,可以在 `content` 列表中添加非文本类型的内容:
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "value": "这张图片里有什么?"},
|
||||
{"type": "image_url", "value": "path/to/image.jpg"}
|
||||
],
|
||||
"loss_weight": 0.0
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": "图片中有一只猫。"}],
|
||||
"loss_weight": 1.0
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**说明**:`image_url`、`audio_url`、`video_url` 的路径可以是相对路径或绝对路径,具体加载方式由 `DataLoaderPlugin` 决定。
|
||||
|
||||
---
|
||||
|
||||
## 数据集配置文件
|
||||
|
||||
### 1. dataset_info.yaml 配置文件格式
|
||||
|
||||
`dataset_info.yaml` 支持同时配置多个数据集,支持分别从 HF Hub 和本地获取数据集,数据集默认会混合并打乱顺序。
|
||||
|
||||
**示例配置文件:`data/dataset_info.yaml`**
|
||||
|
||||
```yaml
|
||||
# 数据集 1:使用本地文件 + Alpaca 转换器
|
||||
identity:
|
||||
file_name: ~/data/identity.json #本地数据集文件绝对路径
|
||||
converter: alpaca # 使用 alpaca 转换器
|
||||
|
||||
# 数据集 2:指定自定义数据集目录
|
||||
alpaca_en_demo:
|
||||
file_name: ~/data/alpaca_en_demo.json # 数据集文件名
|
||||
converter: alpaca # 转换器插件
|
||||
size: 500 # 只使用 500 个样本
|
||||
weight: 0.5 # 数据集权重,用于控制该数据集的采样频率
|
||||
split: train # 数据集划分,默认为 train
|
||||
streaming: false # 是否流式加载,默认为 false
|
||||
|
||||
# 数据集 3:从 Hugging Face Hub 加载
|
||||
hf_dataset:
|
||||
hf_hub_url: llamafactory/v1-sft-demo # HF repo ID
|
||||
streaming: false
|
||||
|
||||
# 数据集 4:已经是标准格式,无需转换器
|
||||
standard:
|
||||
file_name: ~/data/v1_sft_demo.jsonl # 本地标准数据集文件路径
|
||||
|
||||
# 数据集 5:自定义数据集和 converter 插件
|
||||
custom_dataset:
|
||||
file_name: custom_data.json
|
||||
converter: custom_converter
|
||||
weight: 1.0
|
||||
```
|
||||
|
||||
### 2. 配置字段说明
|
||||
|
||||
#### 数据源配置(二者必选其一):
|
||||
|
||||
- **hf_hub_url** (str): Hugging Face Hub 数据集仓库 ID
|
||||
- 示例:`"llamafactory/v1-sft-demo"`
|
||||
- 如果指定,则从 HF Hub 加载数据集
|
||||
|
||||
- **file_name** (str): 本地文件路径
|
||||
- 支持格式:`.json`、`.jsonl`、`.csv`、`.parquet`、`.arrow`、`.txt`
|
||||
|
||||
#### 可选配置:
|
||||
|
||||
- **split** (str): 数据集划分,默认为 `"train"`
|
||||
- **converter** (str): 数据转换器名称
|
||||
- 可选值:`"alpaca"`(更多转换器持续添加中,也可在 data_plugin 中添加自定义 converter)
|
||||
- 如果不指定,则假定数据已是标准格式
|
||||
- **size** (int): 使用的样本数量,默认使用全部
|
||||
- **weight** (float): 数据集权重,用于混合数据集时的采样频率,默认为 1.0
|
||||
- **streaming** (bool): 是否流式加载,默认为 `False`
|
||||
|
||||
---
|
||||
|
||||
|
||||
## 完整示例
|
||||
|
||||
### 1. 基础使用示例
|
||||
|
||||
```python
|
||||
from llamafactory.v1.config.data_args import DataArguments
|
||||
from llamafactory.v1.core.data_engine import DataEngine
|
||||
|
||||
# 使用本地 YAML 配置
|
||||
data_args = DataArguments(
|
||||
dataset="~/data/v1_sft_demo.jsonl",
|
||||
cutoff_len=2048
|
||||
)
|
||||
|
||||
# 初始化 DataEngine
|
||||
engine = DataEngine(data_args=data_args)
|
||||
|
||||
# 查看数据集信息
|
||||
print(f"数据集总样本数: {len(engine)}")
|
||||
print(f"数据集列表: {list(engine.datasets.keys())}")
|
||||
|
||||
# 访问数据样本
|
||||
sample = engine[0]
|
||||
print(f"样本格式: {sample.keys()}")
|
||||
print(f"消息列表: {sample['messages']}")
|
||||
|
||||
# 批量访问
|
||||
batch = engine[0:10]
|
||||
print(f"批量样本数: {len(batch)}")
|
||||
```
|
||||
|
||||
### 2. 输出示例
|
||||
|
||||
**查看数据集信息输出:**
|
||||
|
||||
```
|
||||
数据集总样本数: 500
|
||||
数据集列表: ['default']
|
||||
样本格式: dict_keys(['_dataset_name', 'messages'])
|
||||
消息列表: [{'role': 'user', 'content': [{'type': 'text', 'value': 'hi'}], 'loss_weight': 0.0}, {'role': 'assistant', 'content': [{'type': 'text', 'value': 'Hello! I am {{name}}, an AI assistant developed by {{author}}. How can I assist you today?'}], 'loss_weight': 1.0}]
|
||||
批量样本数: 10
|
||||
```
|
||||
|
||||
**访问单个样本输出:**
|
||||
|
||||
```python
|
||||
{
|
||||
'_dataset_name': 'alpaca_en_demo',
|
||||
'messages': [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': [{'type': 'text', 'value': 'What is the capital of France?'}],
|
||||
'loss_weight': 0.0
|
||||
},
|
||||
{
|
||||
'role': 'assistant',
|
||||
'content': [{'type': 'text', 'value': 'The capital of France is Paris.'}],
|
||||
'loss_weight': 1.0
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 混合多数据集配置文件示例
|
||||
|
||||
**配置文件:`data/mixed_datasets.yaml`**
|
||||
|
||||
```yaml
|
||||
dataset_1:
|
||||
file_name: alpaca_en_demo.json
|
||||
converter: alpaca
|
||||
weight: 1.0
|
||||
|
||||
dataset_2:
|
||||
file_name: identity.json
|
||||
converter: alpaca
|
||||
weight: 2.0
|
||||
|
||||
dataset_3:
|
||||
hf_hub_url: llamafactory/v1-sft-demo
|
||||
weight: 1.5
|
||||
```
|
||||
|
||||
|
||||
### 4. 多模态数据示例
|
||||
|
||||
**数据文件:`data/multimodal_demo.jsonl`**
|
||||
|
||||
标准化后数据示例:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "value": "Who are they?"},
|
||||
{"type": "image_url", "value": "mllm_demo_data/1.jpg"}
|
||||
],
|
||||
"loss_weight": 0.0
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "value": "They're Kane and Gretzka from Bayern Munich."}
|
||||
],
|
||||
"loss_weight": 1.0
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "value": "What are they doing?"},
|
||||
{"type": "image_url", "value": "mllm_demo_data/1.jpg"}
|
||||
],
|
||||
"loss_weight": 0.0
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "value": "They are celebrating on the soccer field."}
|
||||
],
|
||||
"loss_weight": 1.0
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "value": "Who is he?"},
|
||||
{"type": "image_url", "value": "mllm_demo_data/2.jpg"}
|
||||
],
|
||||
"loss_weight": 0.0
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "value": "He's Thomas Muller from Bayern Munich."}
|
||||
],
|
||||
"loss_weight": 1.0
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "value": "Why is he on the ground?"}
|
||||
],
|
||||
"loss_weight": 0.0
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "value": "Because he's sliding on his knees to celebrate."}
|
||||
],
|
||||
"loss_weight": 1.0
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
```python
|
||||
from llamafactory.v1.config.data_args import DataArguments
|
||||
from llamafactory.v1.core.data_engine import DataEngine
|
||||
|
||||
data_args = DataArguments(dataset="data/multimodal_demo.jsonl")
|
||||
engine = DataEngine(data_args=data_args)
|
||||
|
||||
# 访问多模态样本
|
||||
sample = engine[0]
|
||||
print("用户消息内容:")
|
||||
for content_item in sample['messages'][0]['content']:
|
||||
print(f" 类型: {content_item['type']}, 值: {content_item['value']}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**注意事项**:
|
||||
|
||||
1. 所有数据最终都会转换为标准的 Messages 格式
|
||||
2. 通过 `converter` 插件可以支持多种数据格式
|
||||
3. 通过 `weight` 和 `size` 参数可以灵活控制数据分布
|
||||
4. 支持同时使用本地数据集和 HuggingFace Hub 数据集
|
||||
5. 多模态数据通过在 `content` 中添加不同类型的元素来支持
|
||||
6. 更多细节信息请参考我们的 [API REFERENCE](../dev-guide/core/data-engine.md/#data-engine)
|
||||
253
docs/zh/dev-guide/core/data-engine.md
Normal file
253
docs/zh/dev-guide/core/data-engine.md
Normal file
@@ -0,0 +1,253 @@
|
||||
# DataEngine
|
||||
|
||||
## 1. DataEngine 简介
|
||||
|
||||
|
||||
`DataEngine` 是 LLaMA-Factory v1 数据处理的核心类,继承自 PyTorch 的 `Dataset`,负责各种插件的接入,其他功能(如数据格式转换、数据加载等)均通过插件的形式实现并接入 `DataEngine`。
|
||||
|
||||
`DataEngine`接受一个唯一入参:`DataArguments` 实例,所有的元数据集信息均通过该参数配置传入。
|
||||
|
||||
## 2. DataEngine 与 DataArguments 接口定义
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
""" `DataEngine`初始化入参
|
||||
|
||||
args:
|
||||
dataset (str): 数据集路径,远程数据集 repo id / dataset_info.yaml 路径,或本地数据集路径/dataset_info.yaml路径
|
||||
cutoff_len (int): 数据集截止长度,即数据集最大样本采样数量
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class DataEngine(Dataset):
|
||||
"""数据引擎(DataEngine)
|
||||
|
||||
`DataEngine` 负责数据集的加载与统一管理,支持:
|
||||
- 从本地路径或 Hugging Face Hub 加载数据
|
||||
- 通过插件机制加载自定义数据
|
||||
- 构建统一的数据索引
|
||||
- 支持流式(streaming)与非流式数据访问
|
||||
|
||||
attr:
|
||||
args (DataArguments): 数据参数配置
|
||||
datasets (dict[str, HFDataset]): 数据集名称到数据对象的映射
|
||||
dataset_infos (dict[str, DatasetInfo]): 数据集名称到元信息的映射
|
||||
data_index (list[tuple[str, int]]): 数据索引列表,每项为 (dataset_name, sample_index)
|
||||
streaming (bool): 是否为流式数据集
|
||||
"""
|
||||
|
||||
def __init__(self, data_args: DataArguments) -> None:
|
||||
"""初始化 `DataEngine`
|
||||
|
||||
初始化时自动执行以下步骤:
|
||||
1. 调用 `get_dataset_info`, 从 `data_args` 读取并解析数据集元信息
|
||||
2. 调用 `load_dataset`,根据配置加载数据集
|
||||
3. 调用 `build_data_index`,构建统一的索引列表
|
||||
|
||||
args:
|
||||
data_args (DataArguments): 数据参数配置对象
|
||||
"""
|
||||
...
|
||||
|
||||
def get_dataset_info(self) -> None:
|
||||
"""从配置文件或远程仓库加载数据集元信息
|
||||
|
||||
根据 `self.args.dataset` 确定数据源,数据源支持如下选项:
|
||||
- 本地 YAML 配置文件路径
|
||||
- Hugging Face Hub 上的 YAML 配置文件路径
|
||||
- 本地数据集文件路径
|
||||
- Hugging Face Hub 数据集 repo id
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def load_dataset(self) -> None:
|
||||
"""根据数据集元信息加载所有数据集
|
||||
|
||||
每个数据集条目可以包含以下字段:
|
||||
- `hf_hub_url`: 使用 `datasets.load_dataset` 加载
|
||||
- 本地数据文件:通过 `DataLoaderPlugin` 插件加载
|
||||
- `streaming`: 是否启用流式模式
|
||||
|
||||
更新:
|
||||
self.datasets (dict): 数据集名称到已加载数据对象的映射
|
||||
self.streaming (bool): 如果任一数据集为流式模式,则设置为 True
|
||||
"""
|
||||
...
|
||||
|
||||
def build_data_index(self) -> None:
|
||||
"""构建统一的数据索引
|
||||
|
||||
为所有数据集创建全局索引列表 `(dataset_name, sample_index)`
|
||||
|
||||
当启用流式模式时,生成固定长度(例如 1000)的占位索引;
|
||||
否则,为每条样本建立索引。
|
||||
|
||||
插件 `DataIndexPlugin` 可根据数据集大小或权重调整索引分布
|
||||
"""
|
||||
...
|
||||
|
||||
def _convert_data_sample(self, raw_sample: dict[str, Any], dataset_name: str) -> Sample:
|
||||
"""将原始样本转换为统一格式
|
||||
|
||||
根据 `dataset_info` 中的 `converter` 字段,调用对应的转换插件,
|
||||
将原始样本标准化为统一的数据结构。
|
||||
|
||||
args:
|
||||
raw_sample (dict[str, Any]): 原始数据样本
|
||||
dataset_name (str): 样本所属的数据集名称
|
||||
|
||||
return:
|
||||
Sample: 转换后的标准化格式样本
|
||||
"""
|
||||
...
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""返回数据集的总样本数
|
||||
|
||||
return:
|
||||
int: 数据集长度
|
||||
如果为流式数据集,返回 `-1`
|
||||
"""
|
||||
...
|
||||
|
||||
def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
|
||||
"""根据索引或选择器获取样本
|
||||
|
||||
args:
|
||||
index (Union[int, Any]): 数据索引,int 或 list[int]
|
||||
|
||||
return:
|
||||
Union[Sample, list[Sample]]: 单个样本或样本列表
|
||||
"""
|
||||
...
|
||||
|
||||
def __iter__(self) -> Iterable:
|
||||
"""返回数据集迭代器
|
||||
|
||||
用于非流式数据集的顺序或随机访问
|
||||
流式模式下需要实现异步加载逻辑
|
||||
|
||||
return:
|
||||
Iterable: 数据集迭代器。
|
||||
"""
|
||||
...
|
||||
|
||||
async def __aiter__(self) -> AsyncIterable:
|
||||
"""返回异步数据集迭代器
|
||||
|
||||
用于流式数据集或异步数据加载场景
|
||||
允许在异步环境中以流的方式读取样本
|
||||
|
||||
return:
|
||||
AsyncIterable: 异步迭代器,按顺序产出样本
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
```
|
||||
|
||||
`DataArguments` 参数说明:
|
||||
|
||||
`dataset`: 数据集路径,支持本地或远程,当传入本地数据集文件路径时,需要满足该数据集为标准格式;否则需要传入 `dataset_info.yaml` 来配置数据集的 `converter` 等元信息,以告知 `DataEngine` 应当如何处理该数据。
|
||||
|
||||
`cutoff_len`: 数据集的截止长度,即该数据集的最大样本数量。
|
||||
|
||||
---
|
||||
|
||||
## 3. DataEngine 核心方法
|
||||
|
||||
### 3.1 `get_dataset_info`:加载数据元信息
|
||||
|
||||
根据 `dataset` 参数加载数据集配置,获取数据位置、数据格式、插件配置等所有数据元信息,在实例化 `DataEngine` 时会自动调用此方法。
|
||||
|
||||
### 3.2 加载数据集:`load_dataset`
|
||||
|
||||
遍历所有数据源,根据不同的数据源加载数据,在实例化 `DataEngine` 时会自动调用此方法。
|
||||
|
||||
```python
|
||||
for key, value in self.dataset_infos.items():
|
||||
split = value.get("split", "train")
|
||||
streaming = value.get("streaming", False)
|
||||
|
||||
if "hf_hub_url" in value:
|
||||
# 从 HF Hub 加载
|
||||
dataset = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
|
||||
else:
|
||||
# 使用 DataLoaderPlugin 加载本地文件
|
||||
dataset = DataLoaderPlugin(args=self.args).auto_load_data(value)
|
||||
|
||||
self.datasets[key] = dataset
|
||||
```
|
||||
|
||||
### 3.3 `build_data_index`:构建数据索引
|
||||
|
||||
为每个数据集创建索引列表 `[(dataset_name, sample_index), ...]`, `DataIndexPlugin`插件在此处被调用,可控制各数据集的采样频率、采样方式等,在实例化`DataEngine`时会自动调用此方法。
|
||||
|
||||
```python
|
||||
for dataset_name, dataset in self.datasets.items():
|
||||
# 创建基础索引
|
||||
data_index = [(dataset_name, idx) for idx in range(len(dataset))]
|
||||
|
||||
# 根据 size 和 weight 调整索引
|
||||
size = self.dataset_infos[dataset_name].get("size")
|
||||
weight = self.dataset_infos[dataset_name].get("weight")
|
||||
if size or weight:
|
||||
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
|
||||
|
||||
self.data_index.extend(data_index)
|
||||
```
|
||||
|
||||
### 3.4 `_convert_data_sample`:数据格式标准化
|
||||
|
||||
将原始数据转换为标准格式,`DataConverterPlugin`插件在此处被调用,具体调用的插件由 `get_dataset_info` 方法获取的 `converter` 信息指定,若 `converter` 为空则假定数据集为标准格式,此方法由`DataEngine`的 `__getitem__` 方法调用。
|
||||
|
||||
```python
|
||||
def _convert_data_sample(self, raw_sample: dict, dataset_name: str) -> Sample:
|
||||
converter = self.dataset_infos[dataset_name].get("converter")
|
||||
if converter is not None:
|
||||
# 使用指定的转换器
|
||||
from ..plugins.data_plugins.converter import get_converter
|
||||
return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
|
||||
else:
|
||||
# 已经是标准格式
|
||||
return {"_dataset_name": dataset_name, **raw_sample}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 初始化
|
||||
|
||||
`DataEngine` 初始化过程只需传入一个构建好的 `DataArguments` 即可,后续可通过该 `DataEngine` 访问数据集中的数据。
|
||||
|
||||
```python
|
||||
from llamafactory.v1.config.data_args import DataArguments
|
||||
from llamafactory.v1.core.data_engine import DataEngine
|
||||
|
||||
# 1. 创建数据参数
|
||||
data_args = DataArguments(
|
||||
dataset="~/data/v1_sft_demo.jsonl",
|
||||
cutoff_len=2048
|
||||
)
|
||||
|
||||
# 2. 初始化 Data Engine
|
||||
data_engine = DataEngine(data_args=data_args)
|
||||
|
||||
# 3. 访问数据
|
||||
sample = data_engine[0] # 获取第一个样本
|
||||
```
|
||||
|
||||
## 5. 数据访问方式
|
||||
|
||||
实例化后的`DataEngine`支持整数索引、列表索引、以及切片等访问方式,其数据读取用法可等价于 Python 列表。
|
||||
|
||||
```python
|
||||
sample = data_engine[0] # 获取第一个样本
|
||||
|
||||
sample = data_engine[0:10] # 获取前 10 个样本
|
||||
|
||||
sample = data_engine[[0, 5, 10]] # 获取指定索引的样本
|
||||
|
||||
```
|
||||
1
docs/zh/dev-guide/core/model-engine.md
Normal file
1
docs/zh/dev-guide/core/model-engine.md
Normal file
@@ -0,0 +1 @@
|
||||
# ModelEngine
|
||||
1
docs/zh/dev-guide/core/trainer.md
Normal file
1
docs/zh/dev-guide/core/trainer.md
Normal file
@@ -0,0 +1 @@
|
||||
# Trainer
|
||||
467
docs/zh/dev-guide/plugins/data-plugins.md
Normal file
467
docs/zh/dev-guide/plugins/data-plugins.md
Normal file
@@ -0,0 +1,467 @@
|
||||
# Data Plugins
|
||||
|
||||
## 1. Data Plugins 简介
|
||||
|
||||
## DataConverterPlugin
|
||||
|
||||
### 1. DataConverterPlugin 简介
|
||||
|
||||
DataConverter 负责将非标准格式的数据集转换为 v1 的标准 Messages 格式。这使得用户可以继续使用现有的数据集(如 Alpaca 格式),而无需手动转换。针对自定义格式的数据集,用户也可以通过构建对应的自定义 DataConverter 插件,来负责其数据格式标准化。
|
||||
|
||||
当前,LLaMA-Factory 已内置了 `Alpaca Converter` 和 `Pair Converter`,这两类数据集可以直接使用对应的 converter 进行标准化,无需自定义转换器。
|
||||
|
||||
|
||||
### 2. Alpaca Converter 详解
|
||||
|
||||
#### 2.1 Alpaca 格式
|
||||
|
||||
Alpaca 格式是一种常见的指令微调数据格式:
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "You are a helpful assistant.",
|
||||
"instruction": "Describe a process of making crepes.",
|
||||
"input": "",
|
||||
"output": "Making crepes is an easy and delicious process..."
|
||||
}
|
||||
```
|
||||
|
||||
#### 2.2 Alpaca Converter 接口定义
|
||||
|
||||
```python
|
||||
class AlpacaSample(TypedDict, total=False):
|
||||
"""Alpaca 格式数据样本结构
|
||||
|
||||
attr:
|
||||
system (str, 可选): 系统提示信息(system prompt),用于设定对话背景或模型行为。
|
||||
instruction (str, 可选): 用户指令(user instruction),通常为任务描述。
|
||||
input (str, 可选): 额外的输入内容(input text),可与 instruction 拼接。
|
||||
output (str, 可选): 模型生成的目标输出(expected response)。
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
"""将 Alpaca 样本转换为 SFT(Supervised Fine-Tuning)标准样本格式
|
||||
|
||||
`alpaca_converter` 将 Alpaca 数据集中一条样本转换为通用的 `SFTSample` 格式
|
||||
该格式用于监督微调(SFT)或多轮对话建模
|
||||
|
||||
转换逻辑:
|
||||
- 若存在 `system` 字段,则生成一条系统消息,loss_weight = 0.0
|
||||
- 若存在 `instruction` 或 `input` 字段,则合并为一条用户消息,loss_weight = 0.0
|
||||
- 若存在 `output` 字段,则生成一条助手机器人回复消息,loss_weight = 1.0
|
||||
|
||||
args:
|
||||
raw_sample (AlpacaSample): 原始 Alpaca 数据样本
|
||||
|
||||
return:
|
||||
SFTSample: 转换后的标准化样本,格式如下:
|
||||
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{"type": "text", "value": "..."}], "loss_weight": 0.0},
|
||||
{"role": "user", "content": [{"type": "text", "value": "..."}], "loss_weight": 0.0},
|
||||
{"role": "assistant", "content": [{"type": "text", "value": "..."}], "loss_weight": 1.0},
|
||||
]
|
||||
}
|
||||
|
||||
example:
|
||||
>>> raw = {"instruction": "请将以下句子翻译成英文:", "input": "你好", "output": "Hello"}
|
||||
>>> alpaca_converter(raw)
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "value": "请将以下句子翻译成英文:你好"}], "loss_weight": 0.0},
|
||||
{"role": "assistant", "content": [{"type": "text", "value": "Hello"}], "loss_weight": 1.0}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
```
|
||||
|
||||
#### 2.3 转换过程
|
||||
|
||||
`alpaca_converter` 函数将 Alpaca 格式转换为标准格式,转换逻辑如下:
|
||||
|
||||
```python
|
||||
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
messages = []
|
||||
|
||||
# 1. 添加系统提示词(如果存在)
|
||||
if "system" in raw_sample:
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "value": raw_sample["system"]}],
|
||||
"loss_weight": 0.0
|
||||
})
|
||||
|
||||
# 2. 添加用户输入(instruction + input)
|
||||
if "instruction" in raw_sample or "input" in raw_sample:
|
||||
user_content = raw_sample.get("instruction", "") + raw_sample.get("input", "")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": user_content}],
|
||||
"loss_weight": 0.0
|
||||
})
|
||||
|
||||
# 3. 添加模型回复
|
||||
if "output" in raw_sample:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": raw_sample["output"]}],
|
||||
"loss_weight": 1.0
|
||||
})
|
||||
|
||||
return {"messages": messages}
|
||||
```
|
||||
|
||||
#### 2.4 转换示例
|
||||
|
||||
**输入(Alpaca 格式):**
|
||||
|
||||
```json
|
||||
{
|
||||
"instruction": "What is the capital of France?",
|
||||
"input": "",
|
||||
"output": "The capital of France is Paris."
|
||||
}
|
||||
```
|
||||
|
||||
**输出(标准格式):**
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": "What is the capital of France?"}],
|
||||
"loss_weight": 0.0
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": "The capital of France is Paris."}],
|
||||
"loss_weight": 1.0
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 自定义转换器
|
||||
|
||||
#### 3.1 创建自定义转换器
|
||||
|
||||
如果用户有自己的数据格式,可以轻松添加自定义转换器将其标准化,实现过程可参考如下示例:
|
||||
|
||||
```python
|
||||
# src/llamafactory/v1/plugins/data_plugins/converter.py
|
||||
|
||||
from typing import TypedDict, NotRequired
|
||||
from ...extras.types import SFTSample
|
||||
|
||||
# 1. 定义输入格式的类型
|
||||
class MyCustomSample(TypedDict, total=False):
|
||||
question: str
|
||||
answer: str
|
||||
context: NotRequired[str]
|
||||
|
||||
# 2. 实现转换逻辑
|
||||
def custom_converter(raw_sample: MyCustomSample) -> SFTSample:
|
||||
messages = []
|
||||
|
||||
# 构建用户消息
|
||||
user_text = raw_sample["question"]
|
||||
if "context" in raw_sample:
|
||||
user_text = f"Context: {raw_sample['context']}\n\nQuestion: {user_text}"
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "value": user_text}],
|
||||
"loss_weight": 0.0
|
||||
})
|
||||
|
||||
# 构建助手消息
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "value": raw_sample["answer"]}],
|
||||
"loss_weight": 1.0
|
||||
})
|
||||
|
||||
return {"messages": messages}
|
||||
|
||||
# 3. 注册 custom_converter
|
||||
#src/llamafactory/v1/plugins/data_plugins/converter.py: CONVERTERS
|
||||
CONVERTERS = {
|
||||
"alpaca": alpaca_converter,
|
||||
"custom": custom_converter, # 添加自定义转换器
|
||||
}
|
||||
```
|
||||
|
||||
#### 3.2 使用自定义转换器
|
||||
|
||||
在 YAML 配置中指定转换器名称:
|
||||
|
||||
```yaml
|
||||
my_dataset:
|
||||
file_name: custom_data.json
|
||||
converter: custom
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## DataLoaderPlugin
|
||||
|
||||
### 1. DataLoaderPlugin 简介
|
||||
|
||||
`DataLoaderPlugin` 负责从本地文件加载数据集,当前支持如下文件格式:
|
||||
|
||||
- **JSON**: `.json`
|
||||
- **JSONL**: `.jsonl`
|
||||
- **CSV**: `.csv`
|
||||
- **Parquet**: `.parquet`
|
||||
- **Arrow**: `.arrow`
|
||||
- **Text**: `.txt`
|
||||
|
||||
### 2. DataLoaderPlugin 接口定义
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class DataLoaderPlugin:
|
||||
"""数据加载插件(DataLoaderPlugin)
|
||||
|
||||
负责根据数据集信息(`DatasetInfo`)自动加载本地或远程数据集。
|
||||
支持多种文件格式(如 CSV、JSON、Parquet、Text、Arrow),并可选择是否以流式方式加载。
|
||||
|
||||
通常由 `DataEngine` 调用,用于统一封装数据加载逻辑。
|
||||
"""
|
||||
|
||||
args: DataArguments
|
||||
"""数据参数对象,包含数据目录、缓存路径、分片等配置信息。"""
|
||||
|
||||
|
||||
def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
|
||||
"""获取数据集文件格式
|
||||
|
||||
根据输入文件路径自动判断应使用的 HuggingFace `load_dataset` 构建器类型。
|
||||
通过文件扩展名推断数据类型,例如 `.csv`、`.jsonl`、`.parquet`、`.txt` 等。
|
||||
|
||||
args:
|
||||
path (str): 数据集文件路径,用于识别文件类型。
|
||||
|
||||
return:
|
||||
Literal["arrow", "csv", "json", "parquet", "text"]:
|
||||
数据构建器名称,用于 `datasets.load_dataset()`。
|
||||
|
||||
example:
|
||||
>>> _get_builder_name("data/train.jsonl")
|
||||
"json"
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
|
||||
"""根据传入的 `dataset_info` 自动选择合适的加载方式
|
||||
|
||||
args:
|
||||
dataset_info (DatasetInfo): 数据集元信息,通常包含:
|
||||
- `file_name`: 数据文件路径
|
||||
- `split`: 数据划分(如 "train"、"test");
|
||||
- `streaming`: 是否启用流式加载
|
||||
|
||||
return:
|
||||
HFDataset: 加载完成的 Hugging Face 数据集对象。
|
||||
|
||||
example:
|
||||
>>> plugin = DataLoaderPlugin(args)
|
||||
>>> ds = plugin.auto_load_data({"file_name": "~/data.json", "split": "train"})
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def load_data_from_file(self, filepath: str, split: str, streaming: bool) -> HFDataset:
|
||||
"""从文件或目录加载数据集
|
||||
|
||||
根据输入路径自动识别文件类型(CSV、JSON、Parquet、Text 等),
|
||||
并通过 `datasets.load_dataset()` 加载数据集。
|
||||
若 `streaming=True`,则将结果转换为迭代式数据集。
|
||||
|
||||
args:
|
||||
filepath (str): 文件路径或目录路径。
|
||||
split (str): 数据划分名称(如 "train"、"validation")。
|
||||
streaming (bool): 是否启用流式加载模式。
|
||||
|
||||
return:
|
||||
HFDataset: 加载后的数据集对象。
|
||||
|
||||
example:
|
||||
>>> plugin.load_data_from_file("data/train.json", "train", False)
|
||||
"""
|
||||
...
|
||||
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## DataIndexPlugin
|
||||
|
||||
### 1. DataIndexPlugin 简介
|
||||
|
||||
`DataIndexPlugin` 负责调整数据索引,支持通过配置 `size`, `weight` 等参数控制数据集样本数量和采样频率。
|
||||
|
||||
- 使用 `size` 参数 限制使用的样本数量:
|
||||
|
||||
```yaml
|
||||
my_dataset:
|
||||
file_name: large_dataset.json
|
||||
size: 1000 # 只使用前 1000 个样本
|
||||
```
|
||||
|
||||
- 使用 `weight` 参数调整数据集在混合数据中的采样频率:
|
||||
|
||||
```yaml
|
||||
dataset_a:
|
||||
file_name: data_a.json
|
||||
weight: 1.0
|
||||
|
||||
dataset_b:
|
||||
file_name: data_b.json
|
||||
weight: 2.0 # dataset_b 的样本出现频率是 dataset_a 的 2 倍
|
||||
```
|
||||
|
||||
**说明**:`weight` 参数适用于在多个数据集混合训练时,调整不同数据集的的采样频率
|
||||
|
||||
- 当 `weight=1.0` 时,数据集按原始比例采样
|
||||
- 当 `weight=2.0` 时,该数据集的索引会复制 2 倍,使其样本出现频率翻倍
|
||||
|
||||
### 2. DataIndexPlugin 接口定义
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class DataIndexPlugin:
|
||||
"""数据索引插件(DataIndexPlugin)
|
||||
|
||||
根据 `size` 和 `weight` 调整数据索引列表,控制数据集的样本数量和采样频率
|
||||
通常在多数据集混合训练时使用,以控制不同数据集在总体样本中的占比。
|
||||
|
||||
在 `DataEngine.build_data_index` 中被自动调用,用于实现样本重采样或加权分布。
|
||||
"""
|
||||
|
||||
def adjust_data_index(
|
||||
self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
|
||||
) -> list[tuple[str, int]]:
|
||||
"""调整数据索引列表
|
||||
|
||||
根据 `size` 或 `weight` 参数对输入的数据索引进行采样、扩展或缩减。
|
||||
若两个参数同时存在,将依次执行基于大小和基于权重的调整。
|
||||
|
||||
args:
|
||||
data_index (list[tuple[str, int]]):
|
||||
数据索引列表,每个元素为 `(dataset_name, sample_index)`。
|
||||
size (Optional[int]):
|
||||
目标样本数量,若指定则根据该数量裁剪或重复样本。
|
||||
weight (Optional[float]):
|
||||
数据集权重,用于控制数据集在混合训练中的采样比例。
|
||||
|
||||
return:
|
||||
list[tuple[str, int]]:
|
||||
调整后的数据索引列表。
|
||||
|
||||
example:
|
||||
>>> plugin = DataIndexPlugin()
|
||||
>>> adjusted = plugin.adjust_data_index([("ds1", i) for i in range(100)], size=50, weight=None)
|
||||
>>> len(adjusted)
|
||||
50
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def adjust_by_size(self, data_index: list[tuple[str, int]], size: int) -> list[tuple[str, int]]:
|
||||
"""根据目标大小调整数据索引
|
||||
|
||||
通过裁剪或重复样本,使索引总数等于 `size`。
|
||||
常用于统一不同数据集的样本数量。
|
||||
|
||||
args:
|
||||
data_index (list[tuple[str, int]]):
|
||||
原始数据索引列表。
|
||||
size (int):
|
||||
目标样本数量。
|
||||
|
||||
return:
|
||||
list[tuple[str, int]]:
|
||||
调整后长度等于 `size` 的数据索引列表。
|
||||
|
||||
example:
|
||||
>>> plugin.adjust_by_size([("ds1", i) for i in range(10)], 20)
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def adjust_by_weight(self, data_index: list[tuple[str, int]], weight: float) -> list[tuple[str, int]]:
|
||||
"""根据权重调整数据索引
|
||||
|
||||
通过加权采样或重复样本,使数据集样本出现频率符合指定权重。
|
||||
常用于多数据源训练中按比例平衡样本。
|
||||
|
||||
args:
|
||||
data_index (list[tuple[str, int]]):
|
||||
原始数据索引列表。
|
||||
weight (float):
|
||||
数据集权重(相对比例,可与其他数据集共同归一化)。
|
||||
|
||||
return:
|
||||
list[tuple[str, int]]:
|
||||
调整后的加权数据索引列表。
|
||||
|
||||
example:
|
||||
>>> plugin.adjust_by_weight([("ds1", i) for i in range(10)], 0.5)
|
||||
"""
|
||||
...
|
||||
|
||||
```
|
||||
---
|
||||
|
||||
## DataSelectorPlugin
|
||||
|
||||
### 1. DataSelectorPlugin 简介
|
||||
|
||||
`DataSelectorPlugin` 为 `DataEngine`提供基于索引访问数据的功能,由 `DataEngine` 的 `__getitem__` 方法自动调用。
|
||||
|
||||
|
||||
### 2. DataSelectorPlugin 接口定义
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class DataSelectorPlugin:
|
||||
"""根据索引选择数据集样本。
|
||||
|
||||
配合 `DataEngine` 使用,通过统一的 `data_index` 结构(包含数据集名与样本索引)来实现灵活的数据选择
|
||||
|
||||
"""
|
||||
|
||||
data_index: list[tuple[str, int]]
|
||||
"""数据索引列表,每个元素为 (dataset_name, sample_index)。"""
|
||||
|
||||
|
||||
def select(self, index: Union[slice, list[int], Any]) -> Union[tuple[str, int], list[tuple[str, int]]]:
|
||||
"""选择数据集样本
|
||||
|
||||
根据输入类型从 `data_index` 中选择对应的样本索引
|
||||
支持三种索引方式:
|
||||
- 切片(slice):返回对应范围内的样本
|
||||
- 索引列表(list[int]):返回指定索引处的多个样本
|
||||
- 其他类型输入将触发异常。
|
||||
|
||||
args:
|
||||
index (Union[slice, list[int], Any]): 数据样本索引
|
||||
可以是切片(`slice`)或索引列表
|
||||
|
||||
return:
|
||||
Union[tuple[str, int], list[tuple[str, int]]]:
|
||||
- 若为单个索引:返回一个 `(dataset_name, sample_index)`
|
||||
- 若为多个索引或切片:返回多个样本的列表
|
||||
|
||||
except:
|
||||
Raises:
|
||||
ValueError: 当输入索引类型不受支持时抛出。
|
||||
...
|
||||
```
|
||||
197
docs/zh/dev-guide/plugins/model-plugins/kernels.md
Normal file
197
docs/zh/dev-guide/plugins/model-plugins/kernels.md
Normal file
@@ -0,0 +1,197 @@
|
||||
# Kernels plugins
|
||||
|
||||
## 概览
|
||||
LLaMA-Factory 通过 Kernels plugins 系统,依据不同硬件设备提供高性能计算内核(kernel)实现。该系统通过注册表机制管理所有 kernel,通过 `@register_kernel` 装饰器实现 kernel 定义后自动注册,由 `apply_kernel` 方法来使能指定的 kernel,`apply_default_kernels` 可使能注册表中当前环境所有可用的默认 kernels。
|
||||
|
||||
## 架构设计
|
||||
|
||||
### 核心组件
|
||||
|
||||
#### 1. Registry(注册表)
|
||||
|
||||
`Registry` 是一个用于管理所有 kernel 实现的静态类。它维护一个字典结构:`{kernel_id: KernelClass}`。
|
||||
|
||||
```python
|
||||
# 注册表结构示例
|
||||
{
|
||||
"npu_fused_rmsnorm": NpuRMSNormKernel,
|
||||
"npu_fused_swiglu": NpuSwiGluKernel,
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
#### 2. register_kernel (装饰器)
|
||||
|
||||
`@register_kernel` 是 `Registry.register` 的别名。所有 kernel 类均应使用该装饰器进行注册。
|
||||
|
||||
**注册机制**:
|
||||
- 装饰器检查类是否继承自 `BaseKernel`。
|
||||
- 检查类是否定义了 `_kernel_id` 和 `_device` 属性。
|
||||
- 检查 `_device` 是否与当前运行环境的加速器类型匹配。如果不匹配,则跳过注册。
|
||||
- 如果一切符合要求,将 kernel 类注册到全局注册表中。
|
||||
|
||||
#### 3. BaseKernel(基类)
|
||||
|
||||
所有 kernel 的实现都必须继承自 `BaseKernel` 抽象基类。`BaseKernel` 定义了 kernel 的基本属性和接口。
|
||||
|
||||
#### 4. 标识系统
|
||||
|
||||
**Kernel ID** (`_kernel_id`):
|
||||
每个 kernel 必须拥有一个唯一的字符串标识符,例如 `"npu_fused_rmsnorm"`。
|
||||
|
||||
**Device Type** (`_device`):
|
||||
kernel 必须声明其支持的设备类型,例如 `DeviceType.NPU` 或 `DeviceType.CUDA`。
|
||||
|
||||
## Kernel 系统 API 设计
|
||||
|
||||
### **Registry**:全局 kernel 注册表
|
||||
|
||||
`Registry` 类提供了注册和获取 kernel 的接口:
|
||||
|
||||
```python
|
||||
class Registry:
|
||||
@classmethod
|
||||
def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
|
||||
"""注册一个 kernel 类"""
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get(cls, kernel_id: str) -> type[BaseKernel] | None:
|
||||
"""根据 ID 获取 kernel 类"""
|
||||
...
|
||||
```
|
||||
|
||||
### **BaseKernel**
|
||||
|
||||
`BaseKernel` 定义了所有 kernel 必须实现的协议:
|
||||
|
||||
- `_kernel_id`: 类属性,kernel 的唯一标识符。
|
||||
- `_device`: 类属性,kernel 支持的设备类型。
|
||||
- `check_deps()`: 类方法,检查 kernel 的依赖项是否满足(如 `torch_npu` 是否安装)。
|
||||
- `apply(**kwargs)`: 抽象类方法,实现 kernel 的具体应用逻辑。
|
||||
|
||||
```python
|
||||
class BaseKernel(ABC):
|
||||
_kernel_id: Any = ""
|
||||
_device: DeviceType = DeviceType.CPU
|
||||
|
||||
@classmethod
|
||||
def check_deps(cls) -> bool:
|
||||
"""检查依赖项"""
|
||||
...
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def apply(cls, **kwargs) -> HFModel:
|
||||
"""应用 kernel 到模型"""
|
||||
...
|
||||
```
|
||||
|
||||
### **scan_all_kernels**
|
||||
|
||||
`scan_all_kernels` 函数会自动扫描 `ops` 目录下的所有 `.py` 文件并导入它们,从而触发 `@register_kernel` 装饰器完成自动注册。
|
||||
|
||||
### **apply_kernel**
|
||||
|
||||
对模型使能指定的 kernel。
|
||||
|
||||
```python
|
||||
def apply_kernel(kernel_id: str, **kwargs) -> HFModel:
|
||||
"""应用指定的 kernel 到模型
|
||||
|
||||
Args:
|
||||
kernel_id: 目标 kernel 的 ID
|
||||
**kwargs: 传递给 kernel.apply 的参数,通常包含 model
|
||||
"""
|
||||
```
|
||||
|
||||
**用法示例**:
|
||||
```python
|
||||
from llamafactory.v1.plugins.model_plugins.kernels import apply_kernel
|
||||
|
||||
model = apply_kernel("npu_fused_rmsnorm", model=model)
|
||||
```
|
||||
|
||||
### **apply_default_kernels**
|
||||
|
||||
对模型使能所有默认注册的 kernel。这是一个高级 API,通常在模型加载流程中自动调用。
|
||||
|
||||
```python
|
||||
def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
|
||||
"""应用所有默认 kernel
|
||||
|
||||
Args:
|
||||
model: HFModel 实例
|
||||
include_kernels: 包含的 kernel ID 列表(逗号分隔字符串),或者 "auto"/True 表示全部
|
||||
"""
|
||||
```
|
||||
|
||||
## 扩展 Kernels
|
||||
|
||||
如果用户有针对特定模型或者设备的 kernel,可以按照下述步骤去实现并接入 LLaMA-Factory。
|
||||
|
||||
### 创建新 Kernel 的步骤
|
||||
|
||||
#### 1. 创建 Kernel 实现文件
|
||||
|
||||
在 `src/llamafactory/v1/plugins/model_plugins/kernels/ops` 下的相应子目录中创建新的 kernel 实现文件,例如 `mlp/cuda_swiglu.py`:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from ......accelerator.helper import DeviceType
|
||||
from ......utils.types import HFModel
|
||||
from ...base import BaseKernel
|
||||
from ...registry import register_kernel
|
||||
|
||||
# 实现具体的 kernel 函数
|
||||
def _cuda_swiglu_forward(self, hidden_state):
|
||||
# ... CUDA 优化实现 ...
|
||||
pass
|
||||
|
||||
@register_kernel
|
||||
class CudaSwiGluKernel(BaseKernel):
|
||||
_kernel_id = "cuda_fused_swiglu"
|
||||
_device = DeviceType.CUDA
|
||||
|
||||
@classmethod
|
||||
def apply(cls, **kwargs) -> HFModel:
|
||||
model = kwargs.get("model")
|
||||
if model is None:
|
||||
raise ValueError("model is required")
|
||||
|
||||
if not cls.check_deps():
|
||||
raise RuntimeError("Dependencies not met")
|
||||
|
||||
# 遍历模型并替换 forward 方法
|
||||
for name, module in model.named_modules():
|
||||
# ... 匹配和替换逻辑 ...
|
||||
pass
|
||||
|
||||
return model
|
||||
```
|
||||
|
||||
#### 2. 自动发现
|
||||
|
||||
由于 `scan_all_kernels` 会自动扫描 `ops` 目录,只要文件位于该目录下且没有语法错误,系统启动时会自动导入并注册,无需手动修改注册表代码。
|
||||
|
||||
#### 3. 测试 Kernel
|
||||
|
||||
创建测试用例验证 kernel 的正确性:
|
||||
|
||||
```python
|
||||
from llamafactory.v1.plugins.model_plugins.kernels import apply_kernel
|
||||
|
||||
# ... 加载模型 ...
|
||||
model = apply_kernel("cuda_fused_swiglu", model=model)
|
||||
# ... 验证 forward 是否被替换 ...
|
||||
```
|
||||
|
||||
## 异常处理
|
||||
|
||||
### 依赖不可用
|
||||
|
||||
`BaseKernel.check_deps()` 默认会检查当前设备类型是否匹配。子类可以重写此方法以添加额外的依赖检查(如检查特定的库是否安装)。如果 `check_deps()` 返回 `False`,`apply()` 方法应当抛出异常或进行相应处理。
|
||||
|
||||
### Kernel ID 未找到
|
||||
|
||||
如果调用 `apply_kernel` 时传入了不存在的 `kernel_id`,会抛出 `ValueError`。
|
||||
71
docs/zh/getting-started.md
Normal file
71
docs/zh/getting-started.md
Normal file
@@ -0,0 +1,71 @@
|
||||
# Getting Started
|
||||
|
||||
|
||||
## 训练方法
|
||||
|
||||
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
|
||||
|:---------------------:| ------------------ | ------------------ | ------------------ | ------------------ |
|
||||
| 指令监督微调 | :white_check_mark: | | | |
|
||||
| 奖励模型训练 | | | | |
|
||||
| DPO 训练 | | | | |
|
||||
|
||||
|
||||
|
||||
|
||||
## 软件依赖
|
||||
|
||||
| 必需项 | 至少 | 推荐 |
|
||||
|:---------------------:|--------|--------|
|
||||
| python | 3.11 | 3.12 |
|
||||
| torch | 2.7.1 | 2.7.1 |
|
||||
| torch-npu(Ascend NPU) | 2.7.1 | 2.7.1 |
|
||||
| torchvision | 0.22.1 | 0.22.1 |
|
||||
| transformers | 5.0.0 | 5.0.0 |
|
||||
| datasets | 3.2.0 | 4.0.0 |
|
||||
| peft | 0.18.1 | 0.18.1 |
|
||||
|
||||
|
||||
| 可选项 | 至少 | 推荐 |
|
||||
|:----------------:|--------|--------|
|
||||
| CUDA(NVIDIA GPU) | 11.6 | 12.2 |
|
||||
| deepspeed | 0.18.4 | 0.18.4 |
|
||||
| flash-attn(NVIDIA GPU) | 2.5.6 | 2.7.2 |
|
||||
|
||||
|
||||
## 如何使用
|
||||
|
||||
### 安装 LLaMA Factory
|
||||
|
||||
> [!IMPORTANT]
|
||||
> 此步骤为必需。
|
||||
|
||||
#### 从源码安装
|
||||
|
||||
```bash
|
||||
git clone --depth 1 https://github.com/hiyouga/LlamaFactory.git
|
||||
cd LlamaFactory
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
|
||||
### 数据准备
|
||||
|
||||
关于数据集文件的格式,请参考 [data-preparation/README.md](data-preparation/README.md) 的内容。你可以使用 HuggingFace / ModelScope 上的数据集或加载本地数据集。
|
||||
|
||||
> [!NOTE]
|
||||
> 使用自定义数据集或自定义数据集格式时,请参照 [data-preparation/README.md](data-preparation/README.md) 进行配置,如有必要,请重新实现自定义数据集的数据处理逻辑,包括对应的`converter`。
|
||||
|
||||
您也可以使用 **[Easy Dataset](https://github.com/ConardLi/easy-dataset)**、**[DataFlow](https://github.com/OpenDCAI/DataFlow)** 和 **[GraphGen](https://github.com/open-sciencelab/GraphGen)** 构建用于微调的合成数据。
|
||||
|
||||
### 快速开始
|
||||
|
||||
下面的命令展示了对 Qwen3-0.6B 模型使用 FSDP2 进行 全参**微调**,两行命令等价。
|
||||
|
||||
```bash
|
||||
export USE_V1=1
|
||||
llamafactory-cli sft examples/v1/train_full/train_full_fsdp2.yaml
|
||||
llamafactory-cli train examples/v1/train_full/train_full_fsdp2.yaml
|
||||
|
||||
```
|
||||
|
||||
高级用法请参考 [advanced](./advanced/README.md)(包括多卡多机微调、分布式、Lora、量化、以及各种加速特性等)。
|
||||
1
docs/zh/hyperparameters/data-argument.md
Normal file
1
docs/zh/hyperparameters/data-argument.md
Normal file
@@ -0,0 +1 @@
|
||||
# Data Argument
|
||||
0
docs/zh/hyperparameters/model-argument.md
Normal file
0
docs/zh/hyperparameters/model-argument.md
Normal file
0
docs/zh/hyperparameters/sample-argument.md
Normal file
0
docs/zh/hyperparameters/sample-argument.md
Normal file
0
docs/zh/hyperparameters/training-argument.md
Normal file
0
docs/zh/hyperparameters/training-argument.md
Normal file
62
docs/zh/index.rst
Normal file
62
docs/zh/index.rst
Normal file
@@ -0,0 +1,62 @@
|
||||
LlamaFactory 文档
|
||||
=================
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Getting Started
|
||||
|
||||
getting-started
|
||||
installation
|
||||
llamaboard-web-ui
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Data Preparation
|
||||
|
||||
data-preparation/data-processing
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Training
|
||||
|
||||
training/sft
|
||||
training/dpo
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Inference
|
||||
|
||||
inference/deploy
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Advanced
|
||||
|
||||
advanced/lora-and-quantization/lora
|
||||
advanced/lora-and-quantization/quantization
|
||||
advanced/distributed/fsdp
|
||||
advanced/distributed/deepspeed
|
||||
advanced/distributed/parallel-dp-tp-ep-sp-cp
|
||||
advanced/custom-kernels/triton
|
||||
advanced/custom-kernels/fused-operators
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Hyperparameters
|
||||
|
||||
hyperparameters/data-argument
|
||||
hyperparameters/model-argument
|
||||
hyperparameters/sample-argument
|
||||
hyperparameters/training-argument
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Dev Guide
|
||||
|
||||
dev-guide/core/data-engine
|
||||
dev-guide/core/model-engine
|
||||
dev-guide/core/trainer
|
||||
dev-guide/plugins/data-plugins
|
||||
dev-guide/plugins/model-plugins/initialization
|
||||
dev-guide/plugins/model-plugins/kernels
|
||||
dev-guide/plugins/model-plugins/rendering
|
||||
1
docs/zh/inference/deploy.md
Normal file
1
docs/zh/inference/deploy.md
Normal file
@@ -0,0 +1 @@
|
||||
# Deploy
|
||||
1
docs/zh/installation.md
Normal file
1
docs/zh/installation.md
Normal file
@@ -0,0 +1 @@
|
||||
# Installation
|
||||
1
docs/zh/llamaboard-web-ui.md
Normal file
1
docs/zh/llamaboard-web-ui.md
Normal file
@@ -0,0 +1 @@
|
||||
# LlamaBoard Web UI
|
||||
1
docs/zh/training/dpo.md
Normal file
1
docs/zh/training/dpo.md
Normal file
@@ -0,0 +1 @@
|
||||
# DPO
|
||||
1
docs/zh/training/sft.md
Normal file
1
docs/zh/training/sft.md
Normal file
@@ -0,0 +1 @@
|
||||
# SFT
|
||||
45
examples/extras/asft/llama2_full_asft.yaml
Normal file
45
examples/extras/asft/llama2_full_asft.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
### model
|
||||
model_name_or_path: models/Llama-2-7b
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||
use_asft_loss: true
|
||||
asft_alpha: 0.1
|
||||
|
||||
### dataset
|
||||
dataset: med
|
||||
template: llama2
|
||||
cutoff_len: 2048
|
||||
max_samples: 10000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/llama2-7b/full/asft2
|
||||
logging_steps: 1
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 4
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 2.0e-5
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
||||
45
examples/extras/asft/qwen2_full_asft.yaml
Normal file
45
examples/extras/asft/qwen2_full_asft.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
### model
|
||||
model_name_or_path: models/Qwen2.5-7B
|
||||
trust_remote_code: true
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
deepspeed: examples/deepspeed/ds_z0_config.json
|
||||
use_asft_loss: true
|
||||
asft_alpha: 0.05
|
||||
|
||||
### dataset
|
||||
dataset: math
|
||||
template: qwen
|
||||
cutoff_len: 2048
|
||||
max_samples: 10000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
dataloader_num_workers: 4
|
||||
|
||||
### output
|
||||
output_dir: saves/qwen2-7b/full/asft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 4
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 5.0e-5
|
||||
num_train_epochs: 1.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
# val_size: 0.1
|
||||
# per_device_eval_batch_size: 1
|
||||
# eval_strategy: steps
|
||||
# eval_steps: 500
|
||||
@@ -28,12 +28,7 @@ save_only_model: false
|
||||
report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
|
||||
|
||||
### ray
|
||||
ray_run_name: qwen3_4b_sft_lora
|
||||
ray_storage_path: ./saves
|
||||
ray_num_workers: 4 # Number of GPUs to use.
|
||||
placement_strategy: PACK
|
||||
resources_per_worker:
|
||||
GPU: 1
|
||||
# ray_init_kwargs:
|
||||
# runtime_env:
|
||||
# env_vars:
|
||||
|
||||
38
examples/v1/train_freeze/train_freeze_sft.yaml
Normal file
38
examples/v1/train_freeze/train_freeze_sft.yaml
Normal file
@@ -0,0 +1,38 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
# Freeze Configuration
|
||||
peft_config:
|
||||
name: freeze
|
||||
freeze_trainable_layers: 2 # Train the last 2 layers
|
||||
freeze_trainable_modules: all # In these layers, train specific modules
|
||||
freeze_extra_modules: null # Extra modules to train (e.g. embed_tokens, lm_head)
|
||||
|
||||
# Kernel Config
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto
|
||||
|
||||
# FSDP Config
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: ./outputs/test_freeze
|
||||
micro_batch_size: 1
|
||||
global_batch_size: 4
|
||||
cutoff_len: 2048
|
||||
learning_rate: 2.0e-5
|
||||
bf16: false
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
24
examples/v1/train_full/train_full_deepspeed.yaml
Normal file
24
examples/v1/train_full/train_full_deepspeed.yaml
Normal file
@@ -0,0 +1,24 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto
|
||||
|
||||
dist_config:
|
||||
name: deepspeed
|
||||
config_file: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/Qwen3-0.6B-deepspeed
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: true
|
||||
max_steps: 10
|
||||
@@ -14,16 +14,12 @@ dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
||||
|
||||
init_config:
|
||||
name: init_on_meta
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 1
|
||||
global_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: false
|
||||
|
||||
7
examples/v1/train_lora/export_lora.yaml
Normal file
7
examples/v1/train_lora/export_lora.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
peft_config:
|
||||
name: lora
|
||||
adapter_name_or_path: ./outputs/test_lora
|
||||
export_dir: ./merge_lora_model
|
||||
export_size: 5
|
||||
infer_dtype: auto
|
||||
39
examples/v1/train_lora/train_lora_sft.yaml
Normal file
39
examples/v1/train_lora/train_lora_sft.yaml
Normal file
@@ -0,0 +1,39 @@
|
||||
model: Qwen/Qwen3-4B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
# PEFT Configuration
|
||||
peft_config:
|
||||
name: lora
|
||||
r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
target_modules: all
|
||||
|
||||
# Kernel Config
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto
|
||||
|
||||
# FSDP Config
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: ./outputs/test_lora
|
||||
micro_batch_size: 1
|
||||
global_batch_size: 4
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: true
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
43
examples/v1/train_qlora/quantization.yaml
Normal file
43
examples/v1/train_qlora/quantization.yaml
Normal file
@@ -0,0 +1,43 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
# PEFT Configuration
|
||||
peft_config:
|
||||
name: lora
|
||||
r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
target_modules: all
|
||||
|
||||
# Kernel Config
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto
|
||||
|
||||
# FSDP Config
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null
|
||||
|
||||
# Quantization Config
|
||||
quant_config:
|
||||
name: bnb # choice: auto/bnb if auto is selected, the quantization method will be automatically selected based on the model and environment.
|
||||
quantization_bit: 4 # choice: 8/4(bnb)
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_quantization
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: false
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
@@ -40,7 +40,7 @@ dependencies = [
|
||||
"torch>=2.4.0",
|
||||
"torchvision>=0.19.0",
|
||||
"torchaudio>=2.4.0",
|
||||
"transformers>=4.51.0,<=5.0.0,!=4.52.0,!=4.57.0",
|
||||
"transformers>=4.55.0,<=5.2.0,!=4.52.0,!=4.57.0",
|
||||
"datasets>=2.16.0,<=4.0.0",
|
||||
"accelerate>=1.3.0,<=1.11.0",
|
||||
"peft>=0.18.0,<=0.18.1",
|
||||
|
||||
@@ -1 +1 @@
|
||||
liger-kernel>=0.5.5
|
||||
liger-kernel>=0.6.3
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
torch==2.7.1
|
||||
torch-npu==2.7.1
|
||||
torch-npu==2.7.1.post2
|
||||
torchvision==0.22.1
|
||||
torchaudio==2.7.1
|
||||
|
||||
@@ -71,6 +71,7 @@ def convert(
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
expert_model_parallel_size: int = 1,
|
||||
virtual_pipeline_model_parallel_size: int | None = None,
|
||||
moe_grouped_gemm: bool | None = None,
|
||||
):
|
||||
"""Convert checkpoint between MCA and HuggingFace formats.
|
||||
|
||||
@@ -84,6 +85,10 @@ def convert(
|
||||
pipeline_model_parallel_size: Pipeline model parallel size
|
||||
expert_model_parallel_size: Expert model parallel size
|
||||
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
|
||||
moe_grouped_gemm: Use grouped gemm for MoE experts. When enabled, expert
|
||||
weights are stored in a flattened format (linear_fc1.weight0, weight1, ...)
|
||||
rather than per-expert format (local_experts.0.linear_fc1.weight, ...).
|
||||
Must match the format used when saving the checkpoint.
|
||||
"""
|
||||
if bf16 and fp16:
|
||||
raise ValueError("bf16 and fp16 cannot be both True.")
|
||||
@@ -97,8 +102,9 @@ def convert(
|
||||
pipeline_model_parallel_size=pipeline_model_parallel_size,
|
||||
expert_model_parallel_size=expert_model_parallel_size,
|
||||
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
|
||||
moe_grouped_gemm=moe_grouped_gemm,
|
||||
transformer_impl="transformer_engine", # hard code here since we default using te for training
|
||||
)
|
||||
|
||||
convert_checkpoint_to_mca(
|
||||
checkpoint_path,
|
||||
output_path,
|
||||
|
||||
@@ -154,25 +154,24 @@ def vllm_infer(
|
||||
batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
|
||||
|
||||
for j in range(len(batch["input_ids"])):
|
||||
multi_modal_data = {}
|
||||
video_metadata_kwargs = None
|
||||
|
||||
if batch["images"][j] is not None:
|
||||
image = batch["images"][j]
|
||||
multi_modal_data = {
|
||||
"image": template_obj.mm_plugin._regularize_images(
|
||||
image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
|
||||
)["images"]
|
||||
}
|
||||
elif batch["videos"][j] is not None:
|
||||
video_metadata, video_metadata_kwargs = None, None
|
||||
multi_modal_data["image"] = template_obj.mm_plugin._regularize_images(
|
||||
image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
|
||||
)["images"]
|
||||
|
||||
if batch["videos"][j] is not None:
|
||||
video = batch["videos"][j]
|
||||
multi_modal_data = {
|
||||
"video": template_obj.mm_plugin._regularize_videos(
|
||||
video,
|
||||
image_max_pixels=image_max_pixels,
|
||||
image_min_pixels=image_min_pixels,
|
||||
video_fps=video_fps,
|
||||
video_maxlen=video_maxlen,
|
||||
)["videos"]
|
||||
}
|
||||
multi_modal_data["video"] = template_obj.mm_plugin._regularize_videos(
|
||||
video,
|
||||
image_max_pixels=image_max_pixels,
|
||||
image_min_pixels=image_min_pixels,
|
||||
video_fps=video_fps,
|
||||
video_maxlen=video_maxlen,
|
||||
)["videos"]
|
||||
if need_video_kwargs:
|
||||
container = av.open(video[0], "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
@@ -192,18 +191,17 @@ def vllm_infer(
|
||||
video_backend="opencv",
|
||||
)
|
||||
multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
|
||||
elif batch["audios"][j] is not None:
|
||||
|
||||
if batch["audios"][j] is not None:
|
||||
audio = batch["audios"][j]
|
||||
audio_data = template_obj.mm_plugin._regularize_audios(
|
||||
audio,
|
||||
sampling_rate=16000,
|
||||
)
|
||||
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
|
||||
else:
|
||||
multi_modal_data = None
|
||||
multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
|
||||
|
||||
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
|
||||
if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
|
||||
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data or None}
|
||||
if video_metadata_kwargs is not None:
|
||||
vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
|
||||
|
||||
vllm_inputs.append(vllm_input_data)
|
||||
|
||||
@@ -88,7 +88,10 @@ def _process_request(
|
||||
|
||||
if request.messages[0].role == Role.SYSTEM:
|
||||
content = request.messages.pop(0).content
|
||||
system = content[0].text if isinstance(content, list) else content
|
||||
if isinstance(content, list):
|
||||
system = content[0].text if content else ""
|
||||
else:
|
||||
system = content
|
||||
else:
|
||||
system = None
|
||||
|
||||
|
||||
@@ -180,35 +180,32 @@ class VllmEngine(BaseEngine):
|
||||
else self.generating_args["skip_special_tokens"],
|
||||
)
|
||||
|
||||
multi_modal_data = {}
|
||||
if images is not None: # add image features
|
||||
multi_modal_data = {
|
||||
"image": self.template.mm_plugin._regularize_images(
|
||||
images,
|
||||
image_max_pixels=self.model_args.image_max_pixels,
|
||||
image_min_pixels=self.model_args.image_min_pixels,
|
||||
)["images"]
|
||||
}
|
||||
elif videos is not None:
|
||||
multi_modal_data = {
|
||||
"video": self.template.mm_plugin._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=self.model_args.video_max_pixels,
|
||||
image_min_pixels=self.model_args.video_min_pixels,
|
||||
video_fps=self.model_args.video_fps,
|
||||
video_maxlen=self.model_args.video_maxlen,
|
||||
)["videos"]
|
||||
}
|
||||
elif audios is not None:
|
||||
multi_modal_data["image"] = self.template.mm_plugin._regularize_images(
|
||||
images,
|
||||
image_max_pixels=self.model_args.image_max_pixels,
|
||||
image_min_pixels=self.model_args.image_min_pixels,
|
||||
)["images"]
|
||||
|
||||
if videos is not None:
|
||||
multi_modal_data["video"] = self.template.mm_plugin._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=self.model_args.video_max_pixels,
|
||||
image_min_pixels=self.model_args.video_min_pixels,
|
||||
video_fps=self.model_args.video_fps,
|
||||
video_maxlen=self.model_args.video_maxlen,
|
||||
)["videos"]
|
||||
|
||||
if audios is not None:
|
||||
audio_data = self.template.mm_plugin._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=self.model_args.audio_sampling_rate,
|
||||
)
|
||||
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
|
||||
else:
|
||||
multi_modal_data = None
|
||||
multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
|
||||
|
||||
result_generator = self.model.generate(
|
||||
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
||||
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data or None},
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
lora_request=self.lora_request,
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
@@ -24,7 +26,7 @@ import torch.nn.functional as F
|
||||
from peft import PeftModel
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, MROPE_MODELS
|
||||
from ..extras.packages import is_pillow_available
|
||||
|
||||
|
||||
@@ -38,6 +40,56 @@ if TYPE_CHECKING:
|
||||
from .template import Template
|
||||
|
||||
|
||||
def _slice_mm_inputs_for_sample(
|
||||
mm_inputs: dict[str, Any],
|
||||
batch_imglens: list[int],
|
||||
batch_vidlens: list[int],
|
||||
batch_idx: int,
|
||||
images_per_subseq: Optional[list[int]] = None,
|
||||
videos_per_subseq: Optional[list[int]] = None,
|
||||
subseq_idx: Optional[int] = None,
|
||||
) -> dict[str, Any]:
|
||||
r"""Slice mm_inputs for one batch sample, optionally for a single sub-sequence when packing.
|
||||
|
||||
image_grid_thw / video_grid_thw have shape [num_items, 3]. Indices for sample batch_idx
|
||||
are batch_imglens[batch_idx] images and batch_vidlens[batch_idx] videos. When subseq_idx
|
||||
is given, further restrict to that sub-seq's counts via packed_*_counts.
|
||||
has_dummy_image=True means only batch[0] will be concated with fake image and no multimodal data.
|
||||
"""
|
||||
image_start_idx = sum(batch_imglens[:batch_idx])
|
||||
image_end_idx = sum(batch_imglens[: batch_idx + 1])
|
||||
video_start_idx = sum(batch_vidlens[:batch_idx])
|
||||
video_end_idx = sum(batch_vidlens[: batch_idx + 1])
|
||||
|
||||
if subseq_idx is not None and images_per_subseq is not None:
|
||||
image_start_idx += sum(images_per_subseq[:subseq_idx])
|
||||
image_end_idx = image_start_idx + images_per_subseq[subseq_idx]
|
||||
|
||||
if subseq_idx is not None and videos_per_subseq is not None:
|
||||
video_start_idx += sum(videos_per_subseq[:subseq_idx])
|
||||
video_end_idx = video_start_idx + videos_per_subseq[subseq_idx]
|
||||
|
||||
sliced_mm_inputs: dict[str, Any] = {}
|
||||
key_to_slice_meta = {
|
||||
"image_grid_thw": (image_start_idx, image_end_idx, True),
|
||||
"video_grid_thw": (video_start_idx, video_end_idx, True),
|
||||
"second_per_grid_ts": (video_start_idx, video_end_idx, False), # qwen2.5vl
|
||||
"video_second_per_grid": (video_start_idx, video_end_idx, False), # qwen omni
|
||||
}
|
||||
|
||||
for key, (start_idx, end_idx, assign_none_when_empty) in key_to_slice_meta.items():
|
||||
if key not in mm_inputs:
|
||||
continue
|
||||
|
||||
mm_value = mm_inputs[key]
|
||||
if mm_value is not None and end_idx > start_idx:
|
||||
sliced_mm_inputs[key] = mm_value[start_idx:end_idx]
|
||||
elif assign_none_when_empty:
|
||||
sliced_mm_inputs[key] = None
|
||||
|
||||
return sliced_mm_inputs
|
||||
|
||||
|
||||
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
|
||||
r"""Expand 2d attention mask to 4d attention mask.
|
||||
|
||||
@@ -105,9 +157,154 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
else:
|
||||
self.get_rope_func = None
|
||||
|
||||
def _compute_rope_position_ids(
|
||||
self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]
|
||||
) -> None:
|
||||
r"""Compute position_ids and rope_deltas via get_rope_func for VLMs."""
|
||||
rope_index_kwargs = {
|
||||
"input_ids": features["input_ids"],
|
||||
"image_grid_thw": mm_inputs.get("image_grid_thw"),
|
||||
"video_grid_thw": mm_inputs.get("video_grid_thw"),
|
||||
"attention_mask": (features["attention_mask"] >= 1).float(),
|
||||
}
|
||||
if features["attention_mask"].sum() == 0:
|
||||
features["position_ids"] = torch.zeros((3, *features["input_ids"].shape))
|
||||
features["rope_deltas"] = torch.zeros(features["input_ids"].shape[0])
|
||||
return
|
||||
|
||||
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
|
||||
image_token_id = getattr(self.model.config, "image_token_id", None)
|
||||
video_token_id = getattr(self.model.config, "video_token_id", None)
|
||||
if image_token_id is not None or video_token_id is not None:
|
||||
mm_token_type_ids = torch.zeros_like(features["input_ids"])
|
||||
if image_token_id is not None:
|
||||
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
|
||||
if video_token_id is not None:
|
||||
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
|
||||
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
|
||||
|
||||
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
|
||||
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
||||
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
|
||||
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
|
||||
|
||||
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
|
||||
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
|
||||
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
|
||||
if feature_attention_mask is not None: # FIXME: need to get video image lengths
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
|
||||
|
||||
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
|
||||
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
|
||||
dim=-1
|
||||
).unsqueeze(-1)
|
||||
else: # for qwen vl
|
||||
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
|
||||
|
||||
def _compute_rope_position_ids_with_packing(
|
||||
self,
|
||||
features: dict[str, "torch.Tensor"],
|
||||
mm_inputs: dict[str, Any],
|
||||
packing_params_list: list[dict[str, Any] | None],
|
||||
batch_imglens: list[int],
|
||||
batch_vidlens: list[int],
|
||||
batch_audlens: list[int],
|
||||
has_dummy_image: bool,
|
||||
) -> None:
|
||||
r"""Compute position_ids and rope_deltas per sample (or per sub-sequence when packed), then merge and validate."""
|
||||
bsz = features["input_ids"].size(0)
|
||||
seq_len = features["input_ids"].size(1)
|
||||
all_position_ids: list[torch.Tensor] = []
|
||||
all_rope_deltas: list[torch.Tensor] = []
|
||||
|
||||
if has_dummy_image:
|
||||
# for [0, seq_len] = [0, unpadded_length + right_padding_length + fake_input_ids_len + collator_padding_length]
|
||||
# FIXME: maybe right_padding_length is large, with improper max_cutoff_len
|
||||
unpadded_length = int(features["attention_mask"][0].bool().sum().item())
|
||||
right_padding_length = int((packing_params_list[0] or {}).get("right_padding_length") or 0)
|
||||
fake_input_padding_length = max(0, seq_len - unpadded_length - right_padding_length)
|
||||
dummy_image_right_padding_mrope = torch.zeros((3, bsz, fake_input_padding_length))
|
||||
dummy_image_right_padding_attention_mask = torch.zeros((bsz, fake_input_padding_length))
|
||||
assert self.tokenizer.padding_side == "right", "padding_side should be right when fake image is injected"
|
||||
dummy_mm_inputs = copy.deepcopy(mm_inputs)
|
||||
|
||||
for sample_idx in range(bsz):
|
||||
sample_packing = (packing_params_list[sample_idx] or {}) if sample_idx < len(packing_params_list) else {}
|
||||
sequence_boundaries = sample_packing.get("sequence_boundaries")
|
||||
num_sub_seqs = (len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1
|
||||
image_subseq_ids = sample_packing.get("image_subseq_ids") or []
|
||||
video_subseq_ids = sample_packing.get("video_subseq_ids") or []
|
||||
images_per_subseq = (
|
||||
[image_subseq_ids.count(i) for i in range(num_sub_seqs)] if image_subseq_ids and num_sub_seqs > 1 else None
|
||||
)
|
||||
videos_per_subseq = (
|
||||
[video_subseq_ids.count(i) for i in range(num_sub_seqs)] if video_subseq_ids and num_sub_seqs > 1 else None
|
||||
)
|
||||
if has_dummy_image:
|
||||
mm_inputs = {}
|
||||
|
||||
if num_sub_seqs <= 1:
|
||||
sample_features = {
|
||||
"input_ids": features["input_ids"],
|
||||
"attention_mask": features["attention_mask"][sample_idx : sample_idx + 1],
|
||||
}
|
||||
mm_inputs_for_sample = _slice_mm_inputs_for_sample(
|
||||
mm_inputs, batch_imglens, batch_vidlens, sample_idx=sample_idx
|
||||
)
|
||||
self._compute_rope_position_ids(sample_features, mm_inputs_for_sample)
|
||||
all_position_ids.append(sample_features["position_ids"])
|
||||
all_rope_deltas.append(sample_features["rope_deltas"])
|
||||
else:
|
||||
# when we do packing, don't need rope_deltas when training.
|
||||
sample_position_ids: list[torch.Tensor] = []
|
||||
for subseq_idx in range(num_sub_seqs):
|
||||
subseq_start = sequence_boundaries[subseq_idx]
|
||||
subseq_end = sequence_boundaries[subseq_idx + 1]
|
||||
subseq_features = {
|
||||
"input_ids": features["input_ids"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
|
||||
"attention_mask": features["attention_mask"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
|
||||
}
|
||||
mm_inputs_for_subseq = _slice_mm_inputs_for_sample(
|
||||
mm_inputs,
|
||||
batch_imglens,
|
||||
batch_vidlens,
|
||||
sample_idx,
|
||||
images_per_subseq,
|
||||
videos_per_subseq,
|
||||
subseq_idx
|
||||
)
|
||||
self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq)
|
||||
sample_position_ids.append(subseq_features["position_ids"])
|
||||
all_position_ids.append(torch.cat(sample_position_ids, dim=-1))
|
||||
|
||||
batch_dim_for_position_ids = 1 if all_position_ids[0].dim() == 3 else 0
|
||||
|
||||
features["position_ids"] = torch.cat(all_position_ids, dim=batch_dim_for_position_ids)
|
||||
if has_dummy_image:
|
||||
mm_inputs = dummy_mm_inputs
|
||||
|
||||
expected_position_ids_shape = (bsz, seq_len) if all_position_ids[0].dim() == 2 else (
|
||||
all_position_ids[0].size(0),
|
||||
bsz,
|
||||
seq_len,
|
||||
)
|
||||
# Check if position_ids shape matches expected shape.
|
||||
# for further usage, we should padding to the right when some padding token on the right.
|
||||
if has_dummy_image:
|
||||
features["position_ids"] = torch.cat([features["position_ids"], dummy_image_right_padding_mrope], dim=-1)
|
||||
features["attention_mask"] = torch.cat([features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1)
|
||||
|
||||
if features["position_ids"].shape != expected_position_ids_shape:
|
||||
raise ValueError(
|
||||
"Merged position_ids shape mismatch: "
|
||||
f"got {features['position_ids'].shape}, expected {expected_position_ids_shape}."
|
||||
)
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
batch_images, batch_videos, batch_audios = [], [], []
|
||||
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
|
||||
packing_params_list: list[dict[str, Any] | None] = []
|
||||
for feature in features:
|
||||
images = feature.pop("images", None) or []
|
||||
videos = feature.pop("videos", None) or []
|
||||
@@ -119,8 +316,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
batch_vidlens.append(len(videos))
|
||||
batch_audlens.append(len(audios))
|
||||
batch_input_ids.append(feature["input_ids"])
|
||||
packing_params_list.append(feature.pop("packing_params", None))
|
||||
|
||||
fake_input_ids = []
|
||||
has_dummy_image = False
|
||||
if (
|
||||
self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
|
||||
): # avoid process hanging in zero3/fsdp case
|
||||
@@ -136,6 +335,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
fake_input_ids.extend(_fake_input_ids)
|
||||
batch_images = fake_images
|
||||
batch_imglens[0] = 1
|
||||
has_dummy_image = True
|
||||
|
||||
if (
|
||||
self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
|
||||
@@ -182,45 +382,50 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
|
||||
features: dict[str, torch.Tensor] = super().__call__(features)
|
||||
|
||||
bsz, seq_len = features["input_ids"].shape[:2]
|
||||
model_type = getattr(self.model.config, "model_type", None) if self.model is not None else None
|
||||
is_omni = model_type in [
|
||||
"qwen2_5_omni_thinker",
|
||||
"qwen3_omni_moe_thinker",
|
||||
]
|
||||
|
||||
if self.get_rope_func is not None:
|
||||
rope_index_kwargs = {
|
||||
"input_ids": features["input_ids"],
|
||||
"image_grid_thw": mm_inputs.get("image_grid_thw"),
|
||||
"video_grid_thw": mm_inputs.get("video_grid_thw"),
|
||||
"attention_mask": (features["attention_mask"] >= 1).float(),
|
||||
}
|
||||
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
|
||||
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
||||
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
|
||||
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
|
||||
# for mmrope situation, we should calculate position_ids and rope_deltas per sample.
|
||||
# When neat_packing is on, each sample has packing_params; None means no packing for that sample.
|
||||
boundaries_list = [
|
||||
p.get("sequence_boundaries") if p is not None else None for p in packing_params_list
|
||||
]
|
||||
has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
|
||||
if has_dummy_image and has_packing:
|
||||
# FIXME: too tricky, need to be refactored
|
||||
features["has_dummy_image"] = True
|
||||
|
||||
if getattr(self.model.config, "model_type", None) in ["qwen2_5_omni_thinker", "qwen3_omni_moe_thinker"]:
|
||||
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
|
||||
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
|
||||
if feature_attention_mask is not None: # FIXME: need to get video image lengths
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
|
||||
# When fake image/audio was injected, sequence_boundaries no longer match the tensor; use non-packing path.
|
||||
if not has_packing:
|
||||
self._compute_rope_position_ids(features, mm_inputs)
|
||||
else:
|
||||
if is_omni:
|
||||
raise RuntimeError("Omni models are not supported for packed sequences for now.")
|
||||
|
||||
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
|
||||
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
|
||||
dim=-1
|
||||
).unsqueeze(-1)
|
||||
else: # for qwen vl
|
||||
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
|
||||
self._compute_rope_position_ids_with_packing(
|
||||
features,
|
||||
mm_inputs,
|
||||
packing_params_list,
|
||||
batch_imglens,
|
||||
batch_vidlens,
|
||||
batch_audlens,
|
||||
has_dummy_image,
|
||||
)
|
||||
|
||||
# For transformers compatibility, after https://github.com/huggingface/transformers/issues/39400
|
||||
if features["position_ids"].dim() == 3:
|
||||
features["position_ids"] = torch.cat(
|
||||
[features["position_ids"][0].unsqueeze(0), features["position_ids"]], dim=0
|
||||
)
|
||||
|
||||
if (
|
||||
self.model is not None
|
||||
and getattr(self.model.config, "model_type", None)
|
||||
in [
|
||||
"glm4v",
|
||||
"Keye",
|
||||
"qwen2_vl",
|
||||
"qwen2_5_vl",
|
||||
"qwen2_5_omni_thinker",
|
||||
"qwen3_omni_moe_thinker",
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
]
|
||||
and getattr(self.model.config, "model_type", None) in MROPE_MODELS
|
||||
and ("position_ids" not in features or features["position_ids"].dim() != 3)
|
||||
):
|
||||
raise ValueError(f"{self.model.config.model_type} requires 3D position ids for mrope.")
|
||||
@@ -248,12 +453,51 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
|
||||
block_diag_attn: bool = False
|
||||
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||
compute_dtype: "torch.dtype" = torch.float32
|
||||
neat_packing: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.neat_packing and self.attn_implementation == "flash_attention_2":
|
||||
if self.model is not None and getattr(self.model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe", "gpt_oss"]:
|
||||
raise ValueError("Neat packing is not supported for qwen3_5, qwen3_5_moe, gpt_oss models for now.")
|
||||
|
||||
@staticmethod
|
||||
def _unpad_packed_features(features: dict[str, Any]) -> None:
|
||||
r"""Trim padded positions for packed FA2 batches."""
|
||||
attention_mask = features.get("attention_mask")
|
||||
if not torch.is_tensor(attention_mask) or attention_mask.dim() != 2 or attention_mask.size(0) != 1:
|
||||
return
|
||||
|
||||
seq_len = attention_mask.size(1)
|
||||
non_padding_indices = torch.nonzero(attention_mask[0] != 0, as_tuple=False).flatten()
|
||||
if non_padding_indices.numel() == seq_len:
|
||||
return
|
||||
|
||||
keys_on_seq_dim_1 = {"input_ids", "labels", "attention_mask", "token_type_ids"}
|
||||
for key, value in list(features.items()):
|
||||
if not torch.is_tensor(value):
|
||||
continue
|
||||
|
||||
if key == "position_ids" and value.size(-1) == seq_len:
|
||||
features[key] = value.index_select(-1, non_padding_indices)
|
||||
elif key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len:
|
||||
features[key] = value.index_select(1, non_padding_indices)
|
||||
elif key in keys_on_seq_dim_1 and value.dim() == 2 and value.size(0) == 1 and value.size(1) == seq_len:
|
||||
features[key] = value.index_select(1, non_padding_indices)
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
features = super().__call__(features)
|
||||
has_dummy_image = features.pop("has_dummy_image", False)
|
||||
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||
|
||||
if self.neat_packing and self.attn_implementation == "flash_attention_2": # FIXME compatibility fa3/fa4
|
||||
assert features["input_ids"].shape[0] == 1, "bsz should be 1 for neat packing"
|
||||
if not has_dummy_image:
|
||||
self._unpad_packed_features(features)
|
||||
|
||||
features["attention_mask"] = None # let transformers handle causal packed mask.
|
||||
|
||||
for key, value in features.items(): # cast data dtype for paligemma
|
||||
if torch.is_tensor(value) and torch.is_floating_point(value):
|
||||
features[key] = value.to(self.compute_dtype)
|
||||
|
||||
@@ -196,7 +196,7 @@ def read_cloud_json(cloud_path: str) -> list[Any]:
|
||||
|
||||
# filter out non-JSON files
|
||||
files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
|
||||
files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
|
||||
files = list(filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files))
|
||||
if not files:
|
||||
raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
|
||||
|
||||
|
||||
@@ -27,11 +27,12 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, Type
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from transformers.image_utils import get_image_size, is_valid_image, to_numpy_array
|
||||
from transformers.image_utils import get_image_size, is_valid_image, make_flat_list_of_images, to_numpy_array
|
||||
from transformers.models.mllama.processing_mllama import (
|
||||
convert_sparse_cross_attention_mask_to_dense,
|
||||
get_cross_attention_token_mask,
|
||||
)
|
||||
from transformers.video_utils import make_batched_videos
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
@@ -47,13 +48,6 @@ if is_pyav_available():
|
||||
import av
|
||||
|
||||
|
||||
if is_transformers_version_greater_than("4.52.0"):
|
||||
from transformers.image_utils import make_flat_list_of_images
|
||||
from transformers.video_utils import make_batched_videos
|
||||
else:
|
||||
from transformers.image_utils import make_batched_videos, make_flat_list_of_images
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from av.stream import Stream
|
||||
from numpy.typing import NDArray
|
||||
@@ -161,7 +155,9 @@ class MMPluginMixin:
|
||||
video_processor: BaseImageProcessor = getattr(
|
||||
processor, "video_processor", getattr(processor, "image_processor", None)
|
||||
)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||
processor, "audio_processor", None
|
||||
)
|
||||
if len(images) != 0 and self.image_token is None:
|
||||
raise ValueError(
|
||||
"This model does not support image input. Please check whether the correct `template` is used."
|
||||
@@ -390,7 +386,9 @@ class MMPluginMixin:
|
||||
mm_inputs.update(video_processor(videos, return_tensors="pt"))
|
||||
|
||||
if len(audios) != 0:
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||
processor, "audio_processor", None
|
||||
)
|
||||
audios = self._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
@@ -1054,7 +1052,9 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
chunk_input=True,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
)
|
||||
audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
|
||||
audio_feature_lens = [
|
||||
x.clone().detach() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in audio_feature_lens
|
||||
]
|
||||
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
||||
if kwargs.get("ret_phs", False):
|
||||
mm_inputs.update({"audio_phs": audio_phs})
|
||||
@@ -1094,7 +1094,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
num_image_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||
video_seqlen = len(mm_inputs["image_sizes"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
||||
num_video_tokens += 1
|
||||
|
||||
@@ -1876,7 +1876,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||
processor, "audio_processor", None
|
||||
)
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
@@ -1981,6 +1983,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
f"Each {VIDEO_PLACEHOLDER} must be followed by an {AUDIO_PLACEHOLDER} when using audio in video."
|
||||
)
|
||||
|
||||
position_id_per_seconds: int = getattr(processor, "position_id_per_seconds", 25)
|
||||
audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
|
||||
video_t_index = (
|
||||
torch.arange(video_grid_thw[num_video_tokens][0])
|
||||
@@ -1992,9 +1995,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
)
|
||||
.flatten()
|
||||
* mm_inputs["video_second_per_grid"][num_video_tokens]
|
||||
* 25 # FIXME hardcode of position_id_per_seconds=25
|
||||
* position_id_per_seconds
|
||||
).long()
|
||||
t_ntoken_per_chunk = 50 # FIXME hardcode: [25 * 2]
|
||||
t_ntoken_per_chunk = position_id_per_seconds * 2
|
||||
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
|
||||
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
|
||||
placeholder_string = ""
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
@@ -27,6 +27,23 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MAX_SU_SEQ_IDX = 2**32 # maximum sub-sequence index
|
||||
|
||||
@dataclass
|
||||
class PackingParams:
|
||||
r"""Metadata for a packed sequence: sub-sequence boundaries and multimodal data indices.
|
||||
|
||||
- sequence_boundaries: cumulative token positions, e.g. [0, 100, 250, 512] means 3 sub-seqs
|
||||
with token ranges [0,100), [100,250), [250,512). Length = num_sub_seqs + 1.
|
||||
- image_subseq_ids / video_subseq_ids / audio_subseq_ids: for each mm item, the 0-based
|
||||
sub-sequence index it belongs to. Length = total number of that mm type in the packed sample.
|
||||
"""
|
||||
|
||||
sequence_boundaries: list[int]
|
||||
image_subseq_ids: list[int]
|
||||
video_subseq_ids: list[int]
|
||||
audio_subseq_ids: list[int]
|
||||
right_padding_length: int
|
||||
|
||||
@dataclass
|
||||
class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
@@ -162,10 +179,17 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
valid_num += 1
|
||||
|
||||
model_inputs = defaultdict(list)
|
||||
requires_packing_params = self.data_args.neat_packing
|
||||
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
|
||||
packed_images, packed_videos, packed_audios = [], [], []
|
||||
if requires_packing_params:
|
||||
sequence_boundaries = [0]
|
||||
image_subseq_ids: list[int] = []
|
||||
video_subseq_ids: list[int] = []
|
||||
audio_subseq_ids: list[int] = []
|
||||
|
||||
for i, length in enumerate(knapsack):
|
||||
index = length2indexes[length].pop()
|
||||
packed_input_ids += batch_input_ids[index]
|
||||
@@ -174,6 +198,15 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
packed_images += batch_images[index]
|
||||
packed_videos += batch_videos[index]
|
||||
packed_audios += batch_audios[index]
|
||||
if requires_packing_params:
|
||||
n_img = len(batch_images[index])
|
||||
n_vid = len(batch_videos[index])
|
||||
n_aud = len(batch_audios[index])
|
||||
sequence_boundaries.append(sequence_boundaries[-1] + len(batch_input_ids[index]))
|
||||
image_subseq_ids.extend([i] * n_img)
|
||||
video_subseq_ids.extend([i] * n_vid)
|
||||
audio_subseq_ids.extend([i] * n_aud)
|
||||
|
||||
if self.data_args.neat_packing:
|
||||
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
|
||||
else:
|
||||
@@ -189,10 +222,23 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
else:
|
||||
packed_attention_masks += [1] * pad_length # more efficient flash_attn
|
||||
|
||||
if requires_packing_params:
|
||||
sequence_boundaries.append(sequence_boundaries[-1] + pad_length)
|
||||
|
||||
if len(packed_input_ids) != self.data_args.cutoff_len + 1:
|
||||
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
||||
|
||||
model_inputs["input_ids"].append(packed_input_ids)
|
||||
if requires_packing_params:
|
||||
packing_params = PackingParams(
|
||||
sequence_boundaries=sequence_boundaries,
|
||||
image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX], # avoid dataset concat error
|
||||
video_subseq_ids=video_subseq_ids or [MAX_SU_SEQ_IDX],
|
||||
audio_subseq_ids=audio_subseq_ids or [MAX_SU_SEQ_IDX],
|
||||
right_padding_length=pad_length,
|
||||
)
|
||||
model_inputs["packing_params"].append(asdict(packing_params))
|
||||
|
||||
model_inputs["attention_mask"].append(packed_attention_masks)
|
||||
model_inputs["position_ids"].append(packed_position_ids)
|
||||
model_inputs["labels"].append(packed_labels)
|
||||
|
||||
@@ -459,6 +459,18 @@ class ReasoningTemplate(Template):
|
||||
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Glm47ReasoningTemplate(ReasoningTemplate):
|
||||
r"""GLM-4.7 uses only the closing </think> tag for empty thinking blocks."""
|
||||
|
||||
@override
|
||||
def add_thought(self, content: str = "") -> str:
|
||||
if not content:
|
||||
return self.thought_words[1]
|
||||
|
||||
return self.thought_words[0] + content + self.thought_words[1]
|
||||
|
||||
|
||||
TEMPLATES: dict[str, "Template"] = {}
|
||||
|
||||
|
||||
@@ -1049,6 +1061,39 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
# copied from glm4 template
|
||||
register_template(
|
||||
name="glm_ocr",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
|
||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||
format_tools=ToolFormatter(tool_format="glm4"),
|
||||
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
|
||||
)
|
||||
|
||||
|
||||
# copied from glm4_moe template
|
||||
register_template(
|
||||
name="glm4_7",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
|
||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||
format_tools=ToolFormatter(tool_format="glm4_moe"),
|
||||
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
thought_words=("<think>", "</think>"),
|
||||
efficient_eos=True,
|
||||
template_class=Glm47ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
# copied from glm4 template
|
||||
register_template(
|
||||
name="glmz1",
|
||||
@@ -1068,7 +1113,7 @@ register_template(
|
||||
register_template(
|
||||
name="gpt_oss",
|
||||
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
|
||||
default_system="You are ChatGPT, a large language model trained by OpenAI.",
|
||||
thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
|
||||
@@ -1984,6 +2029,39 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="qwen3_5",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="qwen3_5"),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="qwen3_5_nothink",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="qwen3_5"),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="sailor",
|
||||
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
|
||||
@@ -2173,3 +2251,24 @@ register_template(
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||
default_system="You are Zephyr, a helpful assistant.",
|
||||
)
|
||||
|
||||
|
||||
# copied from glm4_7 template
|
||||
register_template(
|
||||
name="aeva",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
|
||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||
format_tools=ToolFormatter(tool_format="glm4_moe"),
|
||||
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||
default_system=(
|
||||
"You are an AI assistant named Aeva created by Zongzhi Lou. "
|
||||
"Your answer should be friendly, unbiased, faithful, informative and detailed."
|
||||
),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
thought_words=("<think>", "</think>"),
|
||||
efficient_eos=True,
|
||||
template_class=Glm47ReasoningTemplate,
|
||||
)
|
||||
|
||||
@@ -85,6 +85,21 @@ QWEN_TOOL_PROMPT = (
|
||||
""""arguments": <args-json-object>}}\n</tool_call>"""
|
||||
)
|
||||
|
||||
QWEN35_TOOL_PROMPT = (
|
||||
"\n\n# Tools\n\nYou have access to the following functions:\n\n<tools>{tool_text}"
|
||||
"\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n"
|
||||
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n"
|
||||
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n"
|
||||
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n"
|
||||
"- Function calls MUST follow the specified format: "
|
||||
"an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n"
|
||||
"- Required parameters MUST be specified\n"
|
||||
"- You may provide optional reasoning for your function call in natural language "
|
||||
"BEFORE the function call, but NOT after\n"
|
||||
"- If there is no function call available, answer the question like normal with your current knowledge "
|
||||
"and do not tell the user about function calls\n</IMPORTANT>"
|
||||
)
|
||||
|
||||
SEED_TOOL_PROMPT = (
|
||||
"system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query."
|
||||
"Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing "
|
||||
@@ -453,6 +468,57 @@ class QwenToolUtils(ToolUtils):
|
||||
return results
|
||||
|
||||
|
||||
class Qwen35ToolUtils(ToolUtils):
|
||||
r"""Qwen 3.5 tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
tool = tool.get("function", tool) if tool.get("type") == "function" else tool
|
||||
tool_text += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||
|
||||
return QWEN35_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for func in functions:
|
||||
name, arguments = func.name, json.loads(func.arguments)
|
||||
prompt = f"<tool_call>\n<function={name}>"
|
||||
for key, value in arguments.items():
|
||||
prompt += f"\n<parameter={key}>"
|
||||
if not isinstance(value, str):
|
||||
value = json.dumps(value, ensure_ascii=False)
|
||||
prompt += f"\n{value}\n</parameter>"
|
||||
prompt += "\n</function>\n</tool_call>"
|
||||
function_texts.append(prompt)
|
||||
|
||||
return "\n".join(function_texts)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
results = []
|
||||
regex = re.compile(r"<tool_call>\s*<function=\s*([^\s<>]+)\s*(.*?)\s*</function>\s*</tool_call>", re.DOTALL)
|
||||
for func_name, params_block in re.findall(regex, content):
|
||||
args_dict = {}
|
||||
param_pattern = re.compile(r"<parameter=(.*?)>(.*?)</parameter>", re.DOTALL)
|
||||
for key, raw_value in re.findall(param_pattern, params_block.strip()):
|
||||
value = raw_value.strip()
|
||||
try:
|
||||
parsed_value = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
parsed_value = raw_value.strip()
|
||||
args_dict[key] = parsed_value
|
||||
|
||||
results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))
|
||||
|
||||
return results if results else content
|
||||
|
||||
|
||||
class GLM4MOEToolUtils(QwenToolUtils):
|
||||
r"""GLM-4-MOE tool using template."""
|
||||
|
||||
@@ -662,6 +728,7 @@ TOOLS = {
|
||||
"minimax2": MiniMaxM2ToolUtils(),
|
||||
"mistral": MistralToolUtils(),
|
||||
"qwen": QwenToolUtils(),
|
||||
"qwen3_5": Qwen35ToolUtils(),
|
||||
"glm4_moe": GLM4MOEToolUtils(),
|
||||
"seed_oss": SeedToolUtils(),
|
||||
"ling": LingToolUtils(),
|
||||
|
||||
@@ -65,15 +65,32 @@ MCA_SUPPORTED_MODELS = {
|
||||
"qwen2_vl",
|
||||
"qwen2_5_vl",
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
"qwen3",
|
||||
"qwen3_moe",
|
||||
"qwen3_next",
|
||||
"qwen3_5",
|
||||
"qwen3_5_moe",
|
||||
}
|
||||
|
||||
METHODS = ["full", "freeze", "lora", "oft"]
|
||||
|
||||
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
||||
|
||||
MROPE_MODELS = {
|
||||
"glm4v",
|
||||
"glm_ocr",
|
||||
"Keye",
|
||||
"qwen2_vl",
|
||||
"qwen2_5_vl",
|
||||
"qwen2_5_omni_thinker",
|
||||
"qwen3_omni_moe_thinker",
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
"qwen3_5",
|
||||
"qwen3_5_moe",
|
||||
}
|
||||
|
||||
MULTIMODAL_SUPPORTED_MODELS = set()
|
||||
|
||||
PEFT_METHODS = {"lora", "oft"}
|
||||
@@ -939,6 +956,29 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"GLM-4.7-Flash": {
|
||||
DownloadSource.DEFAULT: "zai-org/GLM-4.7-Flash",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-4.7-Flash",
|
||||
},
|
||||
},
|
||||
template="glm4_7",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"GLM-OCR": {
|
||||
DownloadSource.DEFAULT: "zai-org/GLM-OCR",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/GLM-OCR",
|
||||
},
|
||||
},
|
||||
template="glm_ocr",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"GLM-Z1-0414-9B-Chat": {
|
||||
@@ -2786,6 +2826,66 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen3.5-0.8B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B-Base",
|
||||
},
|
||||
"Qwen3.5-2B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B-Base",
|
||||
},
|
||||
"Qwen3.5-4B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B-Base",
|
||||
},
|
||||
"Qwen3.5-9B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B-Base",
|
||||
},
|
||||
"Qwen3.5-35B-A3B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B-Base",
|
||||
},
|
||||
"Qwen3.5-0.8B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B",
|
||||
},
|
||||
"Qwen3.5-2B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B",
|
||||
},
|
||||
"Qwen3.5-4B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B",
|
||||
},
|
||||
"Qwen3.5-9B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B",
|
||||
},
|
||||
"Qwen3.5-27B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-27B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-27B",
|
||||
},
|
||||
"Qwen3.5-35B-A3B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B",
|
||||
},
|
||||
"Qwen3.5-122B-A10B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-122B-A10B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-122B-A10B",
|
||||
},
|
||||
"Qwen3.5-397B-A17B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-397B-A17B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-397B-A17B",
|
||||
},
|
||||
},
|
||||
template="qwen3_5",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen2-Audio-7B": {
|
||||
@@ -3427,3 +3527,35 @@ register_model_group(
|
||||
},
|
||||
template="zephyr",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Aeva-Flash-Chat": {
|
||||
DownloadSource.DEFAULT: "louzongzhi/Aeva-Flash",
|
||||
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Flash",
|
||||
DownloadSource.OPENMIND: "louzongzhi/Aeva-Flash",
|
||||
},
|
||||
"Aeva-Air-Chat": {
|
||||
DownloadSource.DEFAULT: "louzongzhi/Aeva-Air",
|
||||
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Air",
|
||||
DownloadSource.OPENMIND: "louzongzhi/Aeva-Air",
|
||||
},
|
||||
"Aeva-Chat": {
|
||||
DownloadSource.DEFAULT: "louzongzhi/Aeva",
|
||||
DownloadSource.MODELSCOPE: "louzongktsi/Aeva",
|
||||
DownloadSource.OPENMIND: "louzongzhi/Aeva",
|
||||
},
|
||||
"Aeva-Pro-Chat": {
|
||||
DownloadSource.DEFAULT: "louzongzhi/Aeva-Pro",
|
||||
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Pro",
|
||||
DownloadSource.OPENMIND: "louzongzhi/Aeva-Pro",
|
||||
},
|
||||
"Aeva-Max-Chat": {
|
||||
DownloadSource.DEFAULT: "louzongzhi/Aeva-Max",
|
||||
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Max",
|
||||
DownloadSource.OPENMIND: "louzongzhi/Aeva-Max",
|
||||
},
|
||||
},
|
||||
template="aeva",
|
||||
)
|
||||
|
||||
@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
|
||||
def check_dependencies() -> None:
|
||||
r"""Check the version of the required packages."""
|
||||
check_version("transformers>=4.51.0,<=5.0.0")
|
||||
check_version("transformers>=4.55.0,<=5.2.0")
|
||||
check_version("datasets>=2.16.0,<=4.0.0")
|
||||
check_version("accelerate>=1.3.0,<=1.11.0")
|
||||
check_version("peft>=0.18.0,<=0.18.1")
|
||||
|
||||
@@ -490,6 +490,14 @@ class FinetuningArguments(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use the DFT loss."},
|
||||
)
|
||||
use_asft_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use the ASFT loss."},
|
||||
)
|
||||
asft_alpha: float = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The alpha parameter for ASFT loss to control the power of adaptive weight."},
|
||||
)
|
||||
use_eaft_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use the EAFT loss."},
|
||||
|
||||
@@ -33,7 +33,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
|
||||
from ..extras import logging
|
||||
from ..extras.constants import CHECKPOINT_NAMES, EngineName
|
||||
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
|
||||
from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
|
||||
from ..extras.packages import is_mcore_adapter_available
|
||||
from .data_args import DataArguments
|
||||
from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
@@ -100,6 +100,52 @@ def _parse_args(
|
||||
return tuple(parsed_args)
|
||||
|
||||
|
||||
def _verify_trackio_args(training_args: "TrainingArguments") -> None:
|
||||
"""Validates Trackio-specific arguments.
|
||||
|
||||
Args:
|
||||
training_args: TrainingArguments instance (not a dictionary)
|
||||
"""
|
||||
report_to = training_args.report_to
|
||||
if not report_to:
|
||||
return
|
||||
|
||||
if isinstance(report_to, str):
|
||||
report_to = [report_to]
|
||||
|
||||
if "trackio" not in report_to:
|
||||
return
|
||||
|
||||
# --- Enforce project (required by Trackio) ---
|
||||
if not training_args.project:
|
||||
raise ValueError("`--project` must be specified when using Trackio.")
|
||||
|
||||
# --- Validate trackio_space_id format ---
|
||||
space_id = training_args.trackio_space_id
|
||||
if space_id:
|
||||
if space_id != "trackio" and "/" not in space_id:
|
||||
logger.warning(
|
||||
f"trackio_space_id '{space_id}' should typically be in format "
|
||||
"'org/space' for Hugging Face Spaces deployment."
|
||||
)
|
||||
|
||||
# --- Inform about default project usage ---
|
||||
if training_args.project == "huggingface":
|
||||
logger.info(
|
||||
"Using default project name 'huggingface'. "
|
||||
"Consider setting a custom project name with --project "
|
||||
"for better organization."
|
||||
)
|
||||
|
||||
# --- Validate hub repo privacy flag ---
|
||||
if training_args.hub_private_repo:
|
||||
logger.info("Repository will be created as private on Hugging Face Hub.")
|
||||
|
||||
# --- Recommend run_name for experiment clarity ---
|
||||
if not training_args.run_name:
|
||||
logger.warning("Consider setting --run_name for better experiment tracking clarity.")
|
||||
|
||||
|
||||
def _set_transformers_logging() -> None:
|
||||
if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
@@ -278,8 +324,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support lora reward model.")
|
||||
|
||||
if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
|
||||
raise ValueError("PPO only accepts wandb or tensorboard logger.")
|
||||
if training_args.report_to and any(
|
||||
logger not in ("wandb", "tensorboard", "trackio", "none") for logger in training_args.report_to
|
||||
):
|
||||
raise ValueError("PPO only accepts wandb, tensorboard, or trackio logger.")
|
||||
|
||||
if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
|
||||
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
|
||||
@@ -346,12 +394,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
if model_args.use_kt and is_deepspeed_zero3_enabled():
|
||||
raise ValueError("KTransformers is incompatible with DeepSpeed ZeRO-3.")
|
||||
|
||||
if data_args.neat_packing and is_transformers_version_greater_than("4.53.0"):
|
||||
raise ValueError("Neat packing is incompatible with transformers>=4.53.0.")
|
||||
|
||||
_set_env_vars()
|
||||
_verify_model_args(model_args, data_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||
_verify_trackio_args(training_args)
|
||||
|
||||
if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
|
||||
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
||||
@@ -421,7 +467,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
training_args.resume_from_checkpoint is None
|
||||
and training_args.do_train
|
||||
and os.path.isdir(training_args.output_dir)
|
||||
and not training_args.overwrite_output_dir
|
||||
and not getattr(training_args, "overwrite_output_dir", False) # for mca training args and transformers >= 5.0
|
||||
and can_resume_from_checkpoint
|
||||
):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
|
||||
@@ -77,6 +77,8 @@ def apply_liger_kernel(
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3 as apply_liger_kernel
|
||||
elif model_type == "qwen3_moe":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
|
||||
elif model_type == "qwen3_next":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_next as apply_liger_kernel
|
||||
elif model_type == "gpt_oss":
|
||||
try:
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
|
||||
|
||||
@@ -77,6 +77,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
||||
|
||||
_set_z3_leaf_modules(model, [Glm4MoeMoE])
|
||||
|
||||
if model_type == "glm4_moe_lite":
|
||||
from transformers.models.glm4_moe_lite.modeling_glm4_moe_lite import Glm4MoeLiteMoE
|
||||
|
||||
_set_z3_leaf_modules(model, [Glm4MoeLiteMoE])
|
||||
|
||||
if model_type == "glm4v_moe":
|
||||
from transformers.models.glm4v_moe.modeling_glm4v_moe import Glm4vMoeTextMoE
|
||||
|
||||
@@ -137,6 +142,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
||||
|
||||
_set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
|
||||
|
||||
if model_type == "qwen3_next":
|
||||
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
|
||||
|
||||
_set_z3_leaf_modules(model, [Qwen3NextSparseMoeBlock])
|
||||
|
||||
|
||||
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.moe_aux_loss_coef:
|
||||
|
||||
@@ -37,7 +37,6 @@
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -45,10 +44,6 @@ import torch.nn.functional as F
|
||||
from ...extras import logging
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -105,13 +100,3 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> tuple["torch.Tensor", "tor
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
return indices, cu_seqlens, max_seqlen_in_batch
|
||||
|
||||
|
||||
def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.block_diag_attn:
|
||||
return
|
||||
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
|
||||
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user