Compare commits

201 commits:
0e88c5754f, 3fff875f99, e2d9ab3591, 3db5cf44ea, 994b9089e9, 4c1513a845, 86e009b504, c1e1918db1, 341225a405, 8c93921952, 45367105fc, df71359069, a03d14a9a6, 41d7ca395e, 757573bec1, 16d655b119, f6483de197, da34411bf2, 1891b64072, a14069acf8, 0ea708c226, cb474c7b11, e4d11a117b, 68365045b4, 502555b65d, 0bc52c0aae, 6bf2663b8e, d337de668e, ec372f91e9, 20b1bd8c54, ee17741591, 93a6925ec5, 47405a8e8a, 54ba30c47f, b92214f78b, 71e4404c0d, 5ab997d484, 6e7048831b, 97cd932c19, dfc7a7d5cd, 27e13a8371, bf6ad1fbed, bc71380b59, 137c87ff60, 485b8dc18b, 875f9078d1, d3bfcbd3af, e36db692e7, 460a40756c, 18057e14ef, 025c8fe302, 446129ca7a, 834c4e8ad9, 11d961cf3c, 00b93d8b2f, 281fd5bb89, cb10050cb9, 2935c4cddb, 0d6ec70c6f, 74777b4ded, 5f2bd04799, 9a1a5f9778, edc8aefa59, ee1c786a12, a3e4f2b716, 6685f1fb9e, c89ff328f6, c6f1bc65c0, 0f43c61229, 8567dab167, 0517d7bee5, 5bc0b9b31c, 3d219b91b9, a90c6306f8, 60558388ec, b29a7f8cd6, a1501591e8, 1408aa078d, 5acaa476d6, 8ac4f87c91, 14d3001824, 1ac9389ddc, 0b0e27c2f1, fd1199cce4, 3c9eda8265, 6622cdb43f, 49c28a7dab, a42671c2d7, f17ab6ad92, ca548af2a2, 579997688f, e6ba7ef3e6, 20fdf177e8, f0b01803ea, f5c4841ff2, 1e01283d81, 2196448c21, 96a81ce89d, a715490c2a, 973cf8e980, 4357e42391, 884b49e662, 38c94d2e9c, 67d2eb6b2a, b670fb57db, 188b4be64d, 889c042ecd, 3c4f8eaa55, 6a75d57060, fda2cf677b, cfdf5a5a78, a1437c15f7, 42e7489713, 024760f866, 46f0189e88, edc7498111, 9103fdf866, 95bf795de4, bf99223a80, 9caf9b6f91, 727c7b0dc6, 13d184b280, 12a91774b0, 88018000ac, f6eda1c35d, a2ebdbc112, e930a42083, 4b123f49cb, 556eca918d, 31fcd03f3c, 89d9dd5aa5, d1aad72826, 8e5b4bddf4, 5a7cb9af4e, d1cda4ec68, 8aaf1185a5, b46bd07119, 08fa707085, 72ba29d81a, cf2dc4c444, d82d86e16d, bde31d8600, e115d55585, daea86e047, a4f69d8914, 98f382fda3, cd899734f3, f51b435bcf, 0f82a55305, 9fd7a410bb, 98fb3d015a, bfb2ad7c79, 135bfbf7c1, c6b17ebc20, b55eb30474, cec2f1fc00, 8367ec03a7, 37013f8068, 8360544d65, b5cdef43a1, 2e5d521ed8, dbe35d52d1, 8bcdb6f52c, 5cfcb8262e, 0b331a318b, 5d6cf55208, 9a1ec19845, a79e93f335, abcb94a738, a4f2d5aa6f, 6b738d1c89, f4c518b370, d475dd3809, 5675c47a01, 16e950454e, 2926265a14, af2607de1a, 826d7808b4, 4c89aca243, 43a065bb07, 4513a2cc75, f29c1ac6e5, 05abe47c8b, 6c185a2c57, af2cb33bb2, f16a4a8264, b232552d42, 0edccc11a5, b2f5c0e0db, 5f5d4c1923, a7d7f79855, fa3150548e, c7479751e8, 870a54ac84, 12fcfc2b72, 95ae30f678, 7408e778ca, ba303fd1aa, dd7a1dbfae, f91fe10985, c7ab302c69
@@ -4,10 +4,10 @@
.venv
cache
data
docker
saves
hf_cache
output
examples
.dockerignore
.gitattributes
.gitignore
Dockerfile
.github/ISSUE_TEMPLATE/bug-report.yml (vendored, 10 lines changed)

@@ -1,13 +1,19 @@
name: "\U0001F41B Bug / Help"
description: Create a report to help us improve the LLaMA Factory
body:
- type: markdown
attributes:
value: |
Issues included in **FAQs** or those with **insufficient** information may be closed without a response.
包含在**常见问题**内或提供信息**不完整**的 issues 可能不会被回复。

- type: checkboxes
id: reminder
attributes:
label: Reminder
description: |
Please ensure you have read the README carefully and searched the existing issues.
请确保您已经认真阅读了 README 并且搜索过现有的 Issue。
Please ensure you have read the README carefully and searched the existing issues (including FAQs).
请确保您已经认真阅读了 README 并且搜索过现有的 issues(包括常见问题)。

options:
- label: I have read the README and searched the existing issues.
.github/workflows/label_issue.yml (vendored, 15 lines changed)

@@ -9,9 +9,22 @@ jobs:
label_issue:
runs-on: ubuntu-latest

permissions:
issues: write

steps:
- env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
ISSUE_URL: ${{ github.event.issue.html_url }}
ISSUE_TITLE: ${{ github.event.issue.title }}
run: |
gh issue edit $ISSUE_URL --add-label "pending"
LABEL=pending
NPU_KEYWORDS=(npu huawei ascend 华为 昇腾)
ISSUE_TITLE_LOWER=$(echo $ISSUE_TITLE | tr '[:upper:]' '[:lower:]')
for KEYWORD in ${NPU_KEYWORDS[@]}; do
if [[ $ISSUE_TITLE_LOWER == *$KEYWORD* ]] && [[ $ISSUE_TITLE_LOWER != *input* ]]; then
LABEL=pending,npu
break
fi
done
gh issue edit $ISSUE_URL --add-label $LABEL
.github/workflows/tests.yml (vendored, 8 lines changed)

@@ -20,6 +20,12 @@ jobs:
tests:
runs-on: ubuntu-latest

environment:
name: tests

env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}

steps:
- name: Checkout
uses: actions/checkout@v4

@@ -34,7 +40,7 @@
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install .[torch,dev]
python -m pip install ".[torch,dev]"

- name: Check quality
run: |
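The only functional change in the install step above is the quoting of the extras specifier. A minimal sketch of why that matters, assuming a shell such as zsh that treats unquoted square brackets as glob patterns:

```bash
# Unquoted, zsh tries to glob the brackets and can fail with "no matches found":
#   python -m pip install .[torch,dev]
# Quoting passes the extras specifier to pip verbatim:
python -m pip install ".[torch,dev]"
```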
.gitignore (vendored, 6 lines changed)

@@ -160,6 +160,8 @@ cython_debug/
.idea/

# custom .gitignore
user.config
saves/
cache/
config/
saves/
output/
wandb/
CITATION.cff (11 lines changed)

@@ -12,12 +12,16 @@ authors:
given-names: "Yanhan"
- family-names: "Luo"
given-names: "Zheyan"
- family-names: "Feng"
given-names: "Zhangchi"
- family-names: "Ma"
given-names: "Yongqiang"
title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
url: "https://arxiv.org/abs/2403.13372"
preferred-citation:
type: article
type: conference-paper
conference:
name: "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)"
authors:
- family-names: "Zheng"
given-names: "Yaowei"

@@ -29,9 +33,12 @@ preferred-citation:
given-names: "Yanhan"
- family-names: "Luo"
given-names: "Zheyan"
- family-names: "Feng"
given-names: "Zhangchi"
- family-names: "Ma"
given-names: "Yongqiang"
journal: "arXiv preprint arXiv:2403.13372"
title: "LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models"
url: "https://arxiv.org/abs/2403.13372"
year: 2024
publisher: "Association for Computational Linguistics"
address: "Bangkok, Thailand"
Dockerfile (47 lines changed)

@@ -1,47 +0,0 @@
# Use the NVIDIA official image with PyTorch 2.3.0
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
FROM nvcr.io/nvidia/pytorch:24.02-py3

# Define installation arguments
ARG INSTALL_BNB=false
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG PIP_INDEX=https://pypi.org/simple

# Set the working directory
WORKDIR /app

# Install the requirements
COPY requirements.txt /app/
RUN pip config set global.index-url $PIP_INDEX
RUN python -m pip install --upgrade pip
RUN python -m pip install -r requirements.txt

# Copy the rest of the application into the image
COPY . /app/

# Install the LLaMA Factory
RUN EXTRA_PACKAGES="metrics"; \
if [ "$INSTALL_BNB" = "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
fi; \
if [ "$INSTALL_VLLM" = "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
fi; \
if [ "$INSTALL_DEEPSPEED" = "true" ]; then \
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
fi; \
pip install -e .[$EXTRA_PACKAGES] && \
pip uninstall -y transformer-engine flash-attn

# Set up volumes
VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]

# Expose port 7860 for the LLaMA Board
EXPOSE 7860

# Expose port 8000 for the API service
EXPOSE 8000

# Launch LLaMA Board
CMD [ "llamafactory-cli", "webui" ]
README.md (191 lines changed)

@@ -4,7 +4,7 @@
[](LICENSE)
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[](https://pypi.org/project/llamafactory/)
[](#projects-using-llama-factory)
[](#projects-using-llama-factory)
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
[](https://discord.gg/rKfvV9r9FK)
[](https://twitter.com/llamafactory_ai)

@@ -15,7 +15,7 @@
[](https://trendshift.io/repositories/4535)

👋 Join our [WeChat](assets/wechat.jpg).
👋 Join our [WeChat](assets/wechat.jpg) or [NPU user group](assets/wechat_npu.jpg).

\[ English | [中文](README_zh.md) \]

@@ -48,7 +48,7 @@ Choose your path:
- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
- **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
- **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ.
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
@@ -151,35 +151,32 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

## Supported Models

| Model | Model size | Template |
| --------------------------------------------------------- | -------------------------------- | --------- |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

| Model | Model size | Template |
| ------------------------------------------------------------ | -------------------------------- | --------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

> [!NOTE]
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
@@ -341,7 +338,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

Extra dependencies available: torch, torch_npu, metrics, deepspeed, bitsandbytes, vllm, galore, badam, gptq, awq, aqlm, qwen, modelscope, quality
Extra dependencies available: torch, torch-npu, metrics, deepspeed, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, qwen, modelscope, quality

> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.
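As an illustration of the extras listed above (the exact combination depends on your hardware and backends; this one is only an example), several extras can be requested in a single quoted specifier:

```bash
# Example only: GPU training with metrics and DeepSpeed support.
pip install -e ".[torch,metrics,deepspeed]"
```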
@@ -360,9 +357,7 @@ To enable FlashAttention-2 on the Windows platform, you need to install the prec

<details><summary>For Ascend NPU users</summary>

Join [NPU user group](assets/wechat_npu.jpg).

To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e '.[torch-npu,metrics]'`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:
To install LLaMA Factory on Ascend NPU devices, please specify extra dependencies: `pip install -e ".[torch-npu,metrics]"`. Additionally, you need to install the **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**. Please follow the [installation tutorial](https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/softwareinstall/instg/atlasdeploy_03_0031.html) or use the following commands:

```bash
# replace the url according to your CANN version and devices

@@ -385,15 +380,12 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| torch-npu | 2.1.0 | 2.1.0.post3 |
| deepspeed | 0.13.2 | 0.13.2 |

Docker image:

- 32GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
- 64GB: [Download page](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)

Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.

If you cannot infer model on NPU devices, try setting `do_sample: false` in the configurations.

Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)

</details>
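A minimal sketch of the two NPU notes above, assuming the `llamafactory-cli` entry point shown elsewhere in this diff; the device index and the target config file are illustrative:

```bash
# Select the first Ascend NPU; use ASCEND_RT_VISIBLE_DEVICES instead of CUDA_VISIBLE_DEVICES.
# If inference misbehaves on NPU, also set `do_sample: false` in the inference configuration.
ASCEND_RT_VISIBLE_DEVICES=0 llamafactory-cli webui
```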
### Data Preparation

@@ -426,18 +418,38 @@ llamafactory-cli webui

### Build Docker

#### Use Docker
For CUDA users:

```bash
docker build -f ./Dockerfile \
cd docker/docker-cuda/
docker-compose up -d
docker-compose exec llamafactory bash
```

For Ascend NPU users:

```bash
cd docker/docker-npu/
docker-compose up -d
docker-compose exec llamafactory bash
```

<details><summary>Build without Docker Compose</summary>

For CUDA users:

```bash
docker build -f ./docker/docker-cuda/Dockerfile \
--build-arg INSTALL_BNB=false \
--build-arg INSTALL_VLLM=false \
--build-arg INSTALL_DEEPSPEED=false \
--build-arg INSTALL_FLASHATTN=false \
--build-arg PIP_INDEX=https://pypi.org/simple \
-t llamafactory:latest .

docker run -it --gpus=all \
-v ./hf_cache:/root/.cache/huggingface/ \
docker run -dit --gpus=all \
-v ./hf_cache:/root/.cache/huggingface \
-v ./ms_cache:/root/.cache/modelscope \
-v ./data:/app/data \
-v ./output:/app/output \
-p 7860:7860 \

@@ -445,15 +457,44 @@ docker run -it --gpus=all \
--shm-size 16G \
--name llamafactory \
llamafactory:latest

docker exec -it llamafactory bash
```

#### Use Docker Compose
For Ascend NPU users:

```bash
docker-compose up -d
docker-compose exec llamafactory bash
# Choose docker image upon your environment
docker build -f ./docker/docker-npu/Dockerfile \
--build-arg INSTALL_DEEPSPEED=false \
--build-arg PIP_INDEX=https://pypi.org/simple \
-t llamafactory:latest .

# Change `device` upon your resources
docker run -dit \
-v ./hf_cache:/root/.cache/huggingface \
-v ./ms_cache:/root/.cache/modelscope \
-v ./data:/app/data \
-v ./output:/app/output \
-v /usr/local/dcmi:/usr/local/dcmi \
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v /etc/ascend_install.info:/etc/ascend_install.info \
-p 7860:7860 \
-p 8000:8000 \
--device /dev/davinci0 \
--device /dev/davinci_manager \
--device /dev/devmm_svm \
--device /dev/hisi_hdc \
--shm-size 16G \
--name llamafactory \
llamafactory:latest

docker exec -it llamafactory bash
```

</details>

<details><summary>Details about volume</summary>

- hf_cache: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
@@ -503,38 +544,63 @@ If you have a project that should be incorporated, please contact via email or c
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh’s Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.

@@ -542,6 +608,9 @@ If you have a project that should be incorporated, please contact via email or c
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.

</details>
@@ -549,17 +618,19 @@ If you have a project that should be incorporated, please contact via email or c

This repository is licensed under the [Apache-2.0 License](LICENSE).

Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation

If this work is helpful, please kindly cite as:

```bibtex
@article{zheng2024llamafactory,
@inproceedings{zheng2024llamafactory,
title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Yongqiang Ma},
journal={arXiv preprint arXiv:2403.13372},
author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
address={Bangkok, Thailand},
publisher={Association for Computational Linguistics},
year={2024},
url={http://arxiv.org/abs/2403.13372}
}
README_zh.md (193 lines changed)

@@ -4,7 +4,7 @@
[](LICENSE)
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[](https://pypi.org/project/llamafactory/)
[](#使用了-llama-factory-的项目)
[](#使用了-llama-factory-的项目)
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
[](https://discord.gg/rKfvV9r9FK)
[](https://twitter.com/llamafactory_ai)

@@ -15,7 +15,7 @@
[](https://trendshift.io/repositories/4535)

👋 加入我们的[微信群](assets/wechat.jpg)。
👋 加入我们的[微信群](assets/wechat.jpg)或 [NPU 用户群](assets/wechat_npu.jpg)。

\[ [English](README.md) | 中文 \]

@@ -48,7 +48,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
- **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
- **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
- **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。
- **先进算法**:GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
- **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
- **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。

@@ -151,35 +151,32 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

## 模型

| 模型名 | 模型大小 | Template |
| --------------------------------------------------------- | -------------------------------- | --------- |
| [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
| [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
| [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
| [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
| [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

| 模型名 | 模型大小 | Template |
| ------------------------------------------------------------ | -------------------------------- | --------- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | gemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen/Qwen1.5/Qwen2 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

> [!NOTE]
> 对于所有“基座”(Base)模型,`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Instruct/Chat)模型请务必使用**对应的模板**。
@@ -341,7 +338,7 @@ cd LLaMA-Factory
pip install -e ".[torch,metrics]"
```

可选的额外依赖项:torch、torch_npu、metrics、deepspeed、bitsandbytes、vllm、galore、badam、gptq、awq、aqlm、qwen、modelscope、quality
可选的额外依赖项:torch、torch-npu、metrics、deepspeed、bitsandbytes、hqq、eetq、gptq、awq、aqlm、vllm、galore、badam、qwen、modelscope、quality

> [!TIP]
> 遇到包冲突时,可使用 `pip install --no-deps -e .` 解决。

@@ -360,9 +357,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl

<details><summary>昇腾 NPU 用户指南</summary>

加入 [NPU 用户群](assets/wechat_npu.jpg)。

在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e '.[torch-npu,metrics]'` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit and Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
在昇腾 NPU 设备上安装 LLaMA Factory 时,需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:

```bash
# 请替换 URL 为 CANN 版本和设备型号对应的 URL

@@ -385,15 +380,12 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| torch-npu | 2.1.0 | 2.1.0.post3 |
| deepspeed | 0.13.2 | 0.13.2 |

Docker 镜像:

- 32GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html)
- 64GB:[下载地址](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)

请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。

如果遇到无法正常推理的情况,请尝试设置 `do_sample: false`。

下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)

</details>
### 数据准备

@@ -426,18 +418,38 @@ llamafactory-cli webui

### 构建 Docker

#### 使用 Docker
CUDA 用户:

```bash
docker build -f ./Dockerfile \
cd docker/docker-cuda/
docker-compose up -d
docker-compose exec llamafactory bash
```

昇腾 NPU 用户:

```bash
cd docker/docker-npu/
docker-compose up -d
docker-compose exec llamafactory bash
```

<details><summary>不使用 Docker Compose 构建</summary>

CUDA 用户:

```bash
docker build -f ./docker/docker-cuda/Dockerfile \
--build-arg INSTALL_BNB=false \
--build-arg INSTALL_VLLM=false \
--build-arg INSTALL_DEEPSPEED=false \
--build-arg INSTALL_FLASHATTN=false \
--build-arg PIP_INDEX=https://pypi.org/simple \
-t llamafactory:latest .

docker run -it --gpus=all \
-v ./hf_cache:/root/.cache/huggingface/ \
docker run -dit --gpus=all \
-v ./hf_cache:/root/.cache/huggingface \
-v ./ms_cache:/root/.cache/modelscope \
-v ./data:/app/data \
-v ./output:/app/output \
-p 7860:7860 \

@@ -445,15 +457,44 @@ docker run -it --gpus=all \
--shm-size 16G \
--name llamafactory \
llamafactory:latest

docker exec -it llamafactory bash
```

#### 使用 Docker Compose
昇腾 NPU 用户:

```bash
docker-compose up -d
docker-compose exec llamafactory bash
# 根据您的环境选择镜像
docker build -f ./docker/docker-npu/Dockerfile \
--build-arg INSTALL_DEEPSPEED=false \
--build-arg PIP_INDEX=https://pypi.org/simple \
-t llamafactory:latest .

# 根据您的资源更改 `device`
docker run -dit \
-v ./hf_cache:/root/.cache/huggingface \
-v ./ms_cache:/root/.cache/modelscope \
-v ./data:/app/data \
-v ./output:/app/output \
-v /usr/local/dcmi:/usr/local/dcmi \
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
-v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
-v /etc/ascend_install.info:/etc/ascend_install.info \
-p 7860:7860 \
-p 8000:8000 \
--device /dev/davinci0 \
--device /dev/davinci_manager \
--device /dev/devmm_svm \
--device /dev/hisi_hdc \
--shm-size 16G \
--name llamafactory \
llamafactory:latest

docker exec -it llamafactory bash
```

</details>

<details><summary>数据卷详情</summary>

- hf_cache:使用宿主机的 Hugging Face 缓存文件夹,允许更改为新的目录。
@@ -503,38 +544,63 @@ run_name: test_run # 可选
1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. KDD 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2402.11809)
1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. ACL 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
1. Wu et al. Large Language Models are Parallel Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2403.09073)
1. Zhang et al. EDT: Improving Large Language Models' Generation by Entropy-based Dynamic Temperature Sampling. 2024. [[arxiv]](https://arxiv.org/abs/2403.14541)
1. Weller et al. FollowIR: Evaluating and Teaching Information Retrieval Models to Follow Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2403.15246)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Hongbin Na. CBT-LLM: A Chinese Large Language Model for Cognitive Behavioral Therapy-based Mental Health Question Answering. COLING 2024. [[arxiv]](https://arxiv.org/abs/2403.16008)
1. Zan et al. CodeS: Natural Language to Code Repository via Multi-Layer Sketch. 2024. [[arxiv]](https://arxiv.org/abs/2403.16443)
1. Liu et al. Extensive Self-Contrast Enables Feedback-Free Language Model Alignment. 2024. [[arxiv]](https://arxiv.org/abs/2404.00604)
1. Luo et al. BAdam: A Memory Efficient Full Parameter Training Method for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.02827)
1. Du et al. Chinese Tiny LLM: Pretraining a Chinese-Centric Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2404.04167)
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
1. Ma et al. Parameter Efficient Quasi-Orthogonal Fine-Tuning via Givens Rotation. ICML 2024. [[arxiv]](https://arxiv.org/abs/2404.04316)
1. Liu et al. Dynamic Generation of Personalities with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.07084)
1. Shang et al. How Far Have We Gone in Stripped Binary Code Understanding Using Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.09836)
1. Huang et al. LLMTune: Accelerate Database Knob Tuning with Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2404.11581)
1. Deng et al. Text-Tuple-Table: Towards Information Integration in Text-to-Table Generation via Global Tuple Extraction. 2024. [[arxiv]](https://arxiv.org/abs/2404.14215)
1. Acikgoz et al. Hippocrates: An Open-Source Framework for Advancing Large Language Models in Healthcare. 2024. [[arxiv]](https://arxiv.org/abs/2404.16621)
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2404.17140)
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
1. Zhang et al. Small Language Models Need Strong Verifiers to Self-Correct Reasoning. ACL 2024 Findings. [[arxiv]](https://arxiv.org/abs/2404.17140)
1. Zhou et al. FREB-TQA: A Fine-Grained Robustness Evaluation Benchmark for Table Question Answering. NAACL 2024. [[arxiv]](https://arxiv.org/abs/2404.18585)
1. Xu et al. Large Language Models for Cyber Security: A Systematic Literature Review. 2024. [[arxiv]](https://arxiv.org/abs/2405.04760)
1. Dammu et al. "They are uncultured": Unveiling Covert Harms and Social Threats in LLM Generated Conversations. 2024. [[arxiv]](https://arxiv.org/abs/2405.05378)
1. Yi et al. A safety realignment framework via subspace-oriented model fusion for large language models. 2024. [[arxiv]](https://arxiv.org/abs/2405.09055)
1. Lou et al. SPO: Multi-Dimensional Preference Sequential Alignment With Implicit Reward Modeling. 2024. [[arxiv]](https://arxiv.org/abs/2405.12739)
1. Zhang et al. Getting More from Less: Large Language Models are Good Spontaneous Multilingual Learners. 2024. [[arxiv]](https://arxiv.org/abs/2405.13816)
1. Zhang et al. TS-Align: A Teacher-Student Collaborative Framework for Scalable Iterative Finetuning of Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2405.20215)
1. Zihong Chen. Sentence Segmentation and Sentence Punctuation Based on XunziALLM. 2024. [[paper]](https://aclanthology.org/2024.lt4hala-1.30)
1. Gao et al. The Best of Both Worlds: Toward an Honest and Helpful Large Language Model. 2024. [[arxiv]](https://arxiv.org/abs/2406.00380)
1. Wang and Song. MARS: Benchmarking the Metaphysical Reasoning Abilities of Language Models with a Multi-task Evaluation Dataset. 2024. [[arxiv]](https://arxiv.org/abs/2406.02106)
1. Hu et al. Computational Limits of Low-Rank Adaptation (LoRA) for Transformer-Based Models. 2024. [[arxiv]](https://arxiv.org/abs/2406.03136)
1. Ge et al. Time Sensitive Knowledge Editing through Efficient Finetuning. ACL 2024. [[arxiv]](https://arxiv.org/abs/2406.04496)
1. Tan et al. Peer Review as A Multi-Turn and Long-Context Dialogue with Role-Based Interactions. 2024. [[arxiv]](https://arxiv.org/abs/2406.05688)
1. Song et al. Turbo Sparse: Achieving LLM SOTA Performance with Minimal Activated Parameters. 2024. [[arxiv]](https://arxiv.org/abs/2406.05955)
1. Gu et al. RWKV-CLIP: A Robust Vision-Language Representation Learner. 2024. [[arxiv]](https://arxiv.org/abs/2406.06973)
1. Chen et al. Advancing Tool-Augmented Large Language Models: Integrating Insights from Errors in Inference Trees. 2024. [[arxiv]](https://arxiv.org/abs/2406.07115)
1. Zhu et al. Are Large Language Models Good Statisticians?. 2024. [[arxiv]](https://arxiv.org/abs/2406.07815)
1. Li et al. Know the Unknown: An Uncertainty-Sensitive Method for LLM Instruction Tuning. 2024. [[arxiv]](https://arxiv.org/abs/2406.10099)
1. Ding et al. IntentionQA: A Benchmark for Evaluating Purchase Intention Comprehension Abilities of Language Models in E-commerce. 2024. [[arxiv]](https://arxiv.org/abs/2406.10173)
1. He et al. COMMUNITY-CROSS-INSTRUCT: Unsupervised Instruction Generation for Aligning Large Language Models to Online Communities. 2024. [[arxiv]](https://arxiv.org/abs/2406.12074)
1. Lin et al. FVEL: Interactive Formal Verification Environment with Large Language Models via Theorem Proving. 2024. [[arxiv]](https://arxiv.org/abs/2406.14408)
1. Treutlein et al. Connecting the Dots: LLMs can Infer and Verbalize Latent Structure from Disparate Training Data. 2024. [[arxiv]](https://arxiv.org/abs/2406.14546)
1. Feng et al. SS-Bench: A Benchmark for Social Story Generation and Evaluation. 2024. [[arxiv]](https://arxiv.org/abs/2406.15695)
1. Feng et al. Self-Constructed Context Decompilation with Fined-grained Alignment Enhancement. 2024. [[arxiv]](https://arxiv.org/abs/2406.17233)
1. Liu et al. Large Language Models for Cuffless Blood Pressure Measurement From Wearable Biosignals. 2024. [[arxiv]](https://arxiv.org/abs/2406.18069)
1. Iyer et al. Exploring Very Low-Resource Translation with LLMs: The University of Edinburgh’s Submission to AmericasNLP 2024 Translation Task. AmericasNLP 2024. [[paper]](https://aclanthology.org/2024.americasnlp-1.25)
||||
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: StarWhisper, a large language model for astronomy, fine-tuned from ChatGLM2-6B and Qwen-14B on astronomical data.
|
||||
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: DISC-LawLLM, a large language model for the Chinese legal domain, fine-tuned from Baichuan-13B, with capabilities in legal reasoning and knowledge retrieval.
|
||||
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: Sunsimiao, a Chinese medical large language model fine-tuned from Baichuan-7B and ChatGLM-6B on Chinese medical data.
|
||||
@@ -542,6 +608,9 @@ run_name: test_run # 可选
|
||||
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI-personality large language models, able to give any LLM one of the 16 personality types depending on the dataset and training method.
|
||||
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model for generating Stable Diffusion prompts. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
|
||||
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A Chinese multimodal medical large language model, fine-tuned from LLaVA-1.5-7B on Chinese multimodal medical data.
|
||||
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
|
||||
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: A toolkit for fine-tuning large language models on Windows PCs with NVIDIA RTX devices.
|
||||
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: A low-code development tool for building multi-agent LLM applications, with support for model fine-tuning via LLaMA Factory.
|
||||
|
||||
</details>
|
||||
|
||||
@@ -549,17 +618,19 @@ run_name: test_run # 可选
|
||||
|
||||
The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license.
|
||||
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command-R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [LLaMA-3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
|
||||
|
||||
## Citation
|
||||
|
||||
If this work is helpful, please kindly cite as:
|
||||
|
||||
```bibtex
|
||||
@inproceedings{zheng2024llamafactory,
  title={LlamaFactory: Unified Efficient Fine-Tuning of 100+ Language Models},
  author={Yaowei Zheng and Richong Zhang and Junhao Zhang and Yanhan Ye and Zheyan Luo and Zhangchi Feng and Yongqiang Ma},
  booktitle={Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 3: System Demonstrations)},
  address={Bangkok, Thailand},
  publisher={Association for Computational Linguistics},
  year={2024},
  url={http://arxiv.org/abs/2403.13372}
}
```
|
||||
|
||||
@@ -11,8 +11,9 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
|
||||
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
|
||||
"ranking": "whether the dataset is a preference dataset or not. (default: False)",
|
||||
"subset": "the name of the subset. (optional, default: None)",
|
||||
"split": "the name of dataset split to be used. (optional, default: train)",
|
||||
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
|
||||
"num_samples": "the number of samples in the dataset used for training. (optional, default: None)",
|
||||
"num_samples": "the number of samples in the dataset to be used. (optional, default: None)",
|
||||
"columns (optional)": {
|
||||
"prompt": "the column name in the dataset containing the prompts. (default: instruction)",
|
||||
"query": "the column name in the dataset containing the queries. (default: input)",
|
||||
|
||||
@@ -11,8 +11,9 @@
|
||||
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
||||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||
"subset": "数据集子集的名称(可选,默认:None)",
|
||||
"split": "所使用的数据集切分(可选,默认:train)",
|
||||
"folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
|
||||
"num_samples": "该数据集中用于训练的样本数量。(可选,默认:None)",
|
||||
"num_samples": "该数据集所使用的样本数量。(可选,默认:None)",
|
||||
"columns(可选)": {
|
||||
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
||||
"query": "数据集代表请求的表头名称(默认:input)",
|
||||
|
||||
docker/docker-cuda/Dockerfile (new file, 59 lines)
@@ -0,0 +1,59 @@
|
||||
# Use the NVIDIA official image with PyTorch 2.3.0
|
||||
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
|
||||
FROM nvcr.io/nvidia/pytorch:24.02-py3
|
||||
|
||||
# Define environments
|
||||
ENV MAX_JOBS=4
|
||||
ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
|
||||
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
|
||||
# Define installation arguments
|
||||
ARG INSTALL_BNB=false
|
||||
ARG INSTALL_VLLM=false
|
||||
ARG INSTALL_DEEPSPEED=false
|
||||
ARG INSTALL_FLASHATTN=false
|
||||
ARG PIP_INDEX=https://pypi.org/simple
|
||||
|
||||
# Set the working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Install the requirements
|
||||
COPY requirements.txt /app
|
||||
RUN pip config set global.index-url "$PIP_INDEX" && \
|
||||
pip config set global.extra-index-url "$PIP_INDEX" && \
|
||||
python -m pip install --upgrade pip && \
|
||||
python -m pip install -r requirements.txt
|
||||
|
||||
# Copy the rest of the application into the image
|
||||
COPY . /app
|
||||
|
||||
# Install the LLaMA Factory
|
||||
RUN EXTRA_PACKAGES="metrics"; \
|
||||
if [ "$INSTALL_BNB" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},bitsandbytes"; \
|
||||
fi; \
|
||||
if [ "$INSTALL_VLLM" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},vllm"; \
|
||||
fi; \
|
||||
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
||||
fi; \
|
||||
pip install -e ".[$EXTRA_PACKAGES]"
|
||||
|
||||
# Rebuild flash attention
|
||||
RUN pip uninstall -y transformer-engine flash-attn && \
|
||||
if [ "$INSTALL_FLASHATTN" == "true" ]; then \
|
||||
pip uninstall -y ninja && pip install ninja && \
|
||||
pip install --no-cache-dir flash-attn --no-build-isolation; \
|
||||
fi
|
||||
|
||||
# Set up volumes
|
||||
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
|
||||
|
||||
# Expose port 7860 for the LLaMA Board
|
||||
ENV GRADIO_SERVER_PORT 7860
|
||||
EXPOSE 7860
|
||||
|
||||
# Expose port 8000 for the API service
|
||||
ENV API_PORT 8000
|
||||
EXPOSE 8000
|
||||
@@ -1,18 +1,20 @@
|
||||
services:
|
||||
llamafactory:
|
||||
build:
|
||||
dockerfile: Dockerfile
|
||||
context: .
|
||||
dockerfile: ./docker/docker-cuda/Dockerfile
|
||||
context: ../..
|
||||
args:
|
||||
INSTALL_BNB: false
|
||||
INSTALL_VLLM: false
|
||||
INSTALL_DEEPSPEED: false
|
||||
INSTALL_FLASHATTN: false
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
container_name: llamafactory
|
||||
volumes:
|
||||
- ./hf_cache:/root/.cache/huggingface/
|
||||
- ./data:/app/data
|
||||
- ./output:/app/output
|
||||
- ../../hf_cache:/root/.cache/huggingface
|
||||
- ../../ms_cache:/root/.cache/modelscope
|
||||
- ../../data:/app/data
|
||||
- ../../output:/app/output
|
||||
ports:
|
||||
- "7860:7860"
|
||||
- "8000:8000"
|
||||
docker/docker-npu/Dockerfile (new file, 45 lines)
@@ -0,0 +1,45 @@
|
||||
# Use the Ubuntu 22.04 image with CANN 8.0.rc1
|
||||
# More versions can be found at https://hub.docker.com/r/cosdt/cann/tags
|
||||
# FROM cosdt/cann:8.0.rc1-910-ubuntu22.04
|
||||
FROM cosdt/cann:8.0.rc1-910b-ubuntu22.04
|
||||
# FROM cosdt/cann:8.0.rc1-910-openeuler22.03
|
||||
# FROM cosdt/cann:8.0.rc1-910b-openeuler22.03
|
||||
|
||||
# Define environments
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Define installation arguments
|
||||
ARG INSTALL_DEEPSPEED=false
|
||||
ARG PIP_INDEX=https://pypi.org/simple
|
||||
ARG TORCH_INDEX=https://download.pytorch.org/whl/cpu
|
||||
|
||||
# Set the working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Install the requirements
|
||||
COPY requirements.txt /app
|
||||
RUN pip config set global.index-url "$PIP_INDEX" && \
|
||||
pip config set global.extra-index-url "$TORCH_INDEX" && \
|
||||
python -m pip install --upgrade pip && \
|
||||
python -m pip install -r requirements.txt
|
||||
|
||||
# Copy the rest of the application into the image
|
||||
COPY . /app
|
||||
|
||||
# Install the LLaMA Factory
|
||||
RUN EXTRA_PACKAGES="torch-npu,metrics"; \
|
||||
if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
|
||||
EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
|
||||
fi; \
|
||||
pip install -e ".[$EXTRA_PACKAGES]"
|
||||
|
||||
# Set up volumes
|
||||
VOLUME [ "/root/.cache/huggingface", "/root/.cache/modelscope", "/app/data", "/app/output" ]
|
||||
|
||||
# Expose port 7860 for the LLaMA Board
|
||||
ENV GRADIO_SERVER_PORT 7860
|
||||
EXPOSE 7860
|
||||
|
||||
# Expose port 8000 for the API service
|
||||
ENV API_PORT 8000
|
||||
EXPOSE 8000
|
||||
docker/docker-npu/docker-compose.yml (new file, 31 lines)
@@ -0,0 +1,31 @@
|
||||
services:
|
||||
llamafactory:
|
||||
build:
|
||||
dockerfile: ./docker/docker-npu/Dockerfile
|
||||
context: ../..
|
||||
args:
|
||||
INSTALL_DEEPSPEED: false
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
container_name: llamafactory
|
||||
volumes:
|
||||
- ../../hf_cache:/root/.cache/huggingface
|
||||
- ../../ms_cache:/root/.cache/modelscope
|
||||
- ../../data:/app/data
|
||||
- ../../output:/app/output
|
||||
- /usr/local/dcmi:/usr/local/dcmi
|
||||
- /usr/local/bin/npu-smi:/usr/local/bin/npu-smi
|
||||
- /usr/local/Ascend/driver:/usr/local/Ascend/driver
|
||||
- /etc/ascend_install.info:/etc/ascend_install.info
|
||||
ports:
|
||||
- "7860:7860"
|
||||
- "8000:8000"
|
||||
ipc: host
|
||||
tty: true
|
||||
stdin_open: true
|
||||
command: bash
|
||||
devices:
|
||||
- /dev/davinci0
|
||||
- /dev/davinci_manager
|
||||
- /dev/devmm_svm
|
||||
- /dev/hisi_hdc
|
||||
restart: unless-stopped
|
||||
@@ -94,10 +94,10 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.
|
||||
|
||||
### QLoRA Fine-Tuning
|
||||
|
||||
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
|
||||
#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes/HQQ/EETQ Quantization (Recommended)
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
||||
```
|
||||
|
||||
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
||||
|
||||
@@ -94,10 +94,10 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.
|
||||
|
||||
### QLoRA 微调
|
||||
|
||||
#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)
|
||||
#### 基于 4/8 比特 Bitsandbytes/HQQ/EETQ 量化进行指令监督微调(推荐)
|
||||
|
||||
```bash
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
|
||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
||||
```
|
||||
|
||||
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
||||
|
||||
@@ -6,6 +6,7 @@ stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
use_badam: true
|
||||
badam_mode: layer
|
||||
badam_switch_mode: ascending
|
||||
badam_switch_interval: 50
|
||||
badam_verbose: 2
|
||||
@@ -32,7 +33,6 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
pure_bf16: true
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
examples/extras/badam/llama3_full_sft_ds3.yaml (new file, 42 lines)
@@ -0,0 +1,42 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
do_train: true
|
||||
finetuning_type: full
|
||||
use_badam: true
|
||||
badam_mode: layer
|
||||
badam_switch_mode: ascending
|
||||
badam_switch_interval: 50
|
||||
badam_verbose: 2
|
||||
deepspeed: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 1000
|
||||
overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/full/sft
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
overwrite_output_dir: true
|
||||
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
|
||||
### eval
|
||||
val_size: 0.1
|
||||
per_device_eval_batch_size: 1
|
||||
eval_strategy: steps
|
||||
eval_steps: 500
|
||||
@@ -30,7 +30,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -31,7 +31,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -30,7 +30,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -7,7 +7,7 @@ do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: all
|
||||
pissa_init: true
|
||||
pissa_iter: 4
|
||||
pissa_iter: 16
|
||||
pissa_convert: true
|
||||
|
||||
### dataset
|
||||
@@ -32,7 +32,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
examples/inference/llava1_5.yaml (new file, 3 lines)
@@ -0,0 +1,3 @@
|
||||
model_name_or_path: llava-hf/llava-1.5-7b-hf
|
||||
template: vicuna
|
||||
visual_inputs: true
|
||||
@@ -29,7 +29,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -7,7 +7,7 @@ do_train: true
|
||||
finetuning_type: lora
|
||||
lora_target: all
|
||||
pref_beta: 0.1
|
||||
pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo]
|
||||
pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
|
||||
|
||||
### dataset
|
||||
dataset: dpo_en_demo
|
||||
@@ -31,7 +31,7 @@ learning_rate: 5.0e-6
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -6,8 +6,7 @@ adapter_name_or_path: saves/llama3-8b/lora/sft
|
||||
finetuning_type: lora
|
||||
|
||||
### dataset
|
||||
task: mmlu
|
||||
split: test
|
||||
task: mmlu_test # choices: [mmlu_test, ceval_validation, cmmlu_test]
|
||||
template: fewshot
|
||||
lang: en
|
||||
n_shot: 5
|
||||
|
||||
@@ -30,7 +30,7 @@ learning_rate: 5.0e-6
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -30,7 +30,7 @@ learning_rate: 1.0e-5
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### generate
|
||||
|
||||
@@ -8,7 +8,7 @@ do_predict: true
|
||||
finetuning_type: lora
|
||||
|
||||
### dataset
|
||||
dataset: identity,alpaca_en_demo
|
||||
eval_dataset: identity,alpaca_en_demo
|
||||
template: llama3
|
||||
cutoff_len: 1024
|
||||
max_samples: 50
|
||||
|
||||
@@ -15,7 +15,7 @@ overwrite_cache: true
|
||||
preprocessing_num_workers: 16
|
||||
|
||||
### output
|
||||
output_dir: saves/llama3-8b/lora/sft
|
||||
output_dir: saves/llama3-8b/lora/pretrain
|
||||
logging_steps: 10
|
||||
save_steps: 500
|
||||
plot_loss: true
|
||||
@@ -28,7 +28,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -25,11 +25,11 @@ overwrite_output_dir: true
|
||||
### train
|
||||
per_device_train_batch_size: 1
|
||||
gradient_accumulation_steps: 8
|
||||
learning_rate: 1.0e-5
|
||||
learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -29,7 +29,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -30,7 +30,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -30,7 +30,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -30,7 +30,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -29,7 +29,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -29,7 +29,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -29,7 +29,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
### model
|
||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||
quantization_bit: 4
|
||||
quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
|
||||
|
||||
### method
|
||||
stage: sft
|
||||
@@ -30,7 +31,7 @@ learning_rate: 1.0e-4
|
||||
num_train_epochs: 3.0
|
||||
lr_scheduler_type: cosine
|
||||
warmup_ratio: 0.1
|
||||
fp16: true
|
||||
bf16: true
|
||||
ddp_timeout: 180000000
|
||||
|
||||
### eval
|
||||
@@ -18,3 +18,4 @@ matplotlib>=3.7.0
|
||||
fire
|
||||
packaging
|
||||
pyyaml
|
||||
numpy<2.0.0
|
||||
|
||||
@@ -44,6 +44,7 @@ def calculate_lr(
|
||||
template: str = "default",
|
||||
cutoff_len: int = 1024, # i.e. maximum input length during training
|
||||
is_mistral: bool = False, # mistral model uses a smaller learning rate
|
||||
packing: bool = False,
|
||||
):
|
||||
r"""
|
||||
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
|
||||
@@ -57,19 +58,21 @@ def calculate_lr(
|
||||
dataset_dir=dataset_dir,
|
||||
template=template,
|
||||
cutoff_len=cutoff_len,
|
||||
packing=packing,
|
||||
output_dir="dummy_dir",
|
||||
overwrite_cache=True,
|
||||
do_train=True,
|
||||
)
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
|
||||
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
|
||||
if stage == "pt":
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
elif stage == "sft":
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError("Stage does not supported: {}.".format(stage))
|
||||
|
||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||
valid_tokens, total_tokens = 0, 0
|
||||
|
||||
@@ -83,11 +83,12 @@ def cal_ppl(
|
||||
train_on_prompt=train_on_prompt,
|
||||
output_dir="dummy_dir",
|
||||
overwrite_cache=True,
|
||||
do_train=True,
|
||||
)
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)
|
||||
trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
|
||||
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
|
||||
if stage == "pt":
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
@@ -98,7 +99,7 @@ def cal_ppl(
|
||||
tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError("Stage does not supported: {}.".format(stage))
|
||||
|
||||
dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
|
||||
criterion = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
@@ -44,10 +44,11 @@ def length_cdf(
|
||||
cutoff_len=1_000_000,
|
||||
output_dir="dummy_dir",
|
||||
overwrite_cache=True,
|
||||
do_train=True,
|
||||
)
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||
trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)["train_dataset"]
|
||||
total_num = len(trainset)
|
||||
length_dict = defaultdict(int)
|
||||
for sample in tqdm(trainset["input_ids"]):
|
||||
|
||||
@@ -36,15 +36,19 @@ def quantize_loftq(
|
||||
lora_alpha: int = None,
|
||||
lora_rank: int = 16,
|
||||
lora_dropout: float = 0,
|
||||
lora_target: str = "q_proj,v_proj",
|
||||
lora_target: tuple = ("q_proj", "v_proj"),
|
||||
save_safetensors: bool = True,
|
||||
):
|
||||
r"""
|
||||
Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
|
||||
Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
|
||||
"""
|
||||
if isinstance(lora_target, str):
|
||||
lora_target = [name.strip() for name in lora_target.split(",")]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
||||
|
||||
loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
@@ -52,7 +56,7 @@ def quantize_loftq(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||
lora_dropout=lora_dropout,
|
||||
target_modules=[name.strip() for name in lora_target.split(",")],
|
||||
target_modules=lora_target,
|
||||
init_lora_weights="loftq",
|
||||
loftq_config=loftq_config,
|
||||
)
|
||||
|
||||
@@ -35,21 +35,25 @@ def quantize_pissa(
|
||||
lora_alpha: int = None,
|
||||
lora_rank: int = 16,
|
||||
lora_dropout: float = 0,
|
||||
lora_target: str = "q_proj,v_proj",
|
||||
lora_target: tuple = ("q_proj", "v_proj"),
|
||||
save_safetensors: bool = True,
|
||||
):
|
||||
r"""
|
||||
Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
|
||||
Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
|
||||
"""
|
||||
if isinstance(lora_target, str):
|
||||
lora_target = [name.strip() for name in lora_target.split(",")]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
|
||||
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
|
||||
lora_dropout=lora_dropout,
|
||||
target_modules=[name.strip() for name in lora_target.split(",")],
|
||||
target_modules=lora_target,
|
||||
init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
|
||||
)
|
||||
|
||||
|
||||
setup.py
@@ -39,12 +39,14 @@ extra_require = {
|
||||
"metrics": ["nltk", "jieba", "rouge-chinese"],
|
||||
"deepspeed": ["deepspeed>=0.10.0"],
|
||||
"bitsandbytes": ["bitsandbytes>=0.39.0"],
|
||||
"vllm": ["vllm>=0.4.3"],
|
||||
"galore": ["galore-torch"],
|
||||
"badam": ["badam"],
|
||||
"gptq": ["optimum>=1.16.0", "auto-gptq>=0.5.0"],
|
||||
"hqq": ["hqq"],
|
||||
"eetq": ["eetq"],
|
||||
"gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
|
||||
"awq": ["autoawq"],
|
||||
"aqlm": ["aqlm[gpu]>=1.1.0"],
|
||||
"vllm": ["vllm>=0.4.3"],
|
||||
"galore": ["galore-torch"],
|
||||
"badam": ["badam>=1.2.1"],
|
||||
"qwen": ["transformers_stream_generator"],
|
||||
"modelscope": ["modelscope"],
|
||||
"dev": ["ruff", "pytest"],
|
||||
|
||||
@@ -12,7 +12,28 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Level: api, webui > chat, eval, train > data, model > hparams > extras
|
||||
r"""
|
||||
Efficient fine-tuning of large language models.
|
||||
|
||||
Level:
|
||||
api, webui > chat, eval, train > data, model > hparams > extras
|
||||
|
||||
Dependency graph:
|
||||
main:
|
||||
transformers>=4.41.2
|
||||
datasets>=2.16.0
|
||||
accelerate>=0.30.1
|
||||
peft>=0.11.1
|
||||
trl>=0.8.6
|
||||
attention:
|
||||
transformers>=4.42.4 (gemma+fa2)
|
||||
longlora:
|
||||
transformers>=4.41.2,<=4.42.4
|
||||
packing:
|
||||
transformers>=4.41.2,<=4.42.4
|
||||
patcher:
|
||||
transformers==4.41.2 (chatglm)
|
||||
"""
|
||||
|
||||
from .cli import VERSION
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ def _process_request(
|
||||
|
||||
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
||||
tool_calls = [
|
||||
{"name": tool_call.function.name, "argument": tool_call.function.arguments}
|
||||
{"name": tool_call.function.name, "arguments": tool_call.function.arguments}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
content = json.dumps(tool_calls, ensure_ascii=False)
|
||||
|
||||
@@ -96,7 +96,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
tools: Optional[List[FunctionAvailable]] = None
|
||||
do_sample: bool = True
|
||||
do_sample: Optional[bool] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: int = 1
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from threading import Thread
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
|
||||
|
||||
@@ -115,13 +116,11 @@ class ChatModel:
|
||||
|
||||
|
||||
def run_chat() -> None:
|
||||
try:
|
||||
import platform
|
||||
|
||||
if platform.system() != "Windows":
|
||||
if os.name != "nt":
|
||||
try:
|
||||
import readline # noqa: F401
|
||||
except ImportError:
|
||||
print("Install `readline` for a better experience.")
|
||||
except ImportError:
|
||||
print("Install `readline` for a better experience.")
|
||||
|
||||
chat_model = ChatModel()
|
||||
messages = []
|
||||
|
||||
@@ -54,7 +54,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
self.tokenizer = tokenizer_module["tokenizer"]
|
||||
self.processor = tokenizer_module["processor"]
|
||||
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
|
||||
self.model = load_model(
|
||||
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
|
||||
) # must after fixing tokenizer to resize vocab
|
||||
@@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||
|
||||
if stop is not None:
|
||||
logger.warning("Stop parameter is not supported in Huggingface engine yet.")
|
||||
logger.warning("Stop parameter is not supported by the huggingface engine yet.")
|
||||
|
||||
generating_args = generating_args.copy()
|
||||
generating_args.update(
|
||||
|
||||
@@ -13,13 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_device_count
|
||||
from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
|
||||
from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
|
||||
from ..model import load_config, load_tokenizer
|
||||
from ..model.model_utils.quantization import QuantizationMethod
|
||||
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
||||
@@ -28,7 +29,9 @@ if is_vllm_available():
|
||||
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
if is_vllm_version_greater_than_0_5():
|
||||
if is_vllm_version_greater_than_0_5_1():
|
||||
pass
|
||||
elif is_vllm_version_greater_than_0_5():
|
||||
from vllm.multimodal.image import ImagePixelData
|
||||
else:
|
||||
from vllm.sequence import MultiModalData
|
||||
@@ -53,13 +56,18 @@ class VllmEngine(BaseEngine):
|
||||
generating_args: "GeneratingArguments",
|
||||
) -> None:
|
||||
config = load_config(model_args) # may download model from ms hub
|
||||
if getattr(config, "quantization_config", None): # gptq models should use float16
|
||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
quant_method = quantization_config.get("quant_method", "")
|
||||
if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
|
||||
model_args.infer_dtype = "float16"
|
||||
|
||||
self.can_generate = finetuning_args.stage == "sft"
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
self.tokenizer = tokenizer_module["tokenizer"]
|
||||
self.processor = tokenizer_module["processor"]
|
||||
self.tokenizer.padding_side = "left"
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
|
||||
self.generating_args = generating_args.to_dict()
|
||||
|
||||
engine_args = {
|
||||
@@ -124,7 +132,9 @@ class VllmEngine(BaseEngine):
|
||||
if self.processor is not None and image is not None: # add image features
|
||||
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
|
||||
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
|
||||
if is_vllm_version_greater_than_0_5():
|
||||
if is_vllm_version_greater_than_0_5_1():
|
||||
multi_modal_data = {"image": pixel_values}
|
||||
elif is_vllm_version_greater_than_0_5():
|
||||
multi_modal_data = ImagePixelData(image=pixel_values)
|
||||
else: # TODO: remove vllm 0.4.3 support
|
||||
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
|
||||
|
||||
@@ -74,7 +74,7 @@ class Command(str, Enum):
|
||||
|
||||
|
||||
def main():
|
||||
command = sys.argv.pop(1)
|
||||
command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
|
||||
if command == Command.API:
|
||||
run_api()
|
||||
elif command == Command.CHAT:
|
||||
@@ -91,7 +91,7 @@ def main():
|
||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
||||
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
|
||||
subprocess.run(
|
||||
process = subprocess.run(
|
||||
(
|
||||
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
||||
@@ -106,6 +106,7 @@ def main():
|
||||
),
|
||||
shell=True,
|
||||
)
|
||||
sys.exit(process.returncode)
|
||||
else:
|
||||
run_exp()
|
||||
elif command == Command.WEBDEMO:
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
|
||||
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
|
||||
from .data_utils import Role, split_dataset
|
||||
from .loader import get_dataset
|
||||
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
||||
@@ -21,6 +21,7 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
||||
__all__ = [
|
||||
"KTODataCollatorWithPadding",
|
||||
"PairwiseDataCollatorWithPadding",
|
||||
"SFTDataCollatorWith4DAttentionMask",
|
||||
"Role",
|
||||
"split_dataset",
|
||||
"get_dataset",
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the OpenAccess AI Collective's axolotl library.
|
||||
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,19 +16,76 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Sequence
|
||||
from typing import Any, Dict, Literal, Sequence
|
||||
|
||||
import torch
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
|
||||
r"""
|
||||
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
|
||||
while handling packed sequences and transforming the mask to a lower-triangular form to prevent future peeking.
|
||||
|
||||
e.g.
|
||||
```python
|
||||
# input
|
||||
[[1, 1, 2, 2, 2, 0]]
|
||||
# output
|
||||
[
|
||||
[
|
||||
[
|
||||
[o, x, x, x, x, x],
|
||||
[o, o, x, x, x, x],
|
||||
[x, x, o, x, x, x],
|
||||
[x, x, o, o, x, x],
|
||||
[x, x, o, o, o, x],
|
||||
[x, x, x, x, x, x],
|
||||
]
|
||||
]
|
||||
]
|
||||
```
|
||||
where `o` equals `0.0` and `x` equals `min_dtype`.
|
||||
"""
|
||||
bsz, seq_len = attention_mask_with_indices.size()
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
padding_mask = torch.where(expanded_mask != 0, 1, 0)
|
||||
# Create a block-diagonal mask.
|
||||
attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
|
||||
# Use the lower triangular mask to zero out the upper triangular part
|
||||
attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
|
||||
# Invert the attention mask.
|
||||
attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
|
||||
return attention_mask_4d
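# Illustrative usage sketch (not part of this diff): applying `prepare_4d_attention_mask`
# to the packed example from the docstring above, where indices 1 and 2 mark two packed
# sequences and 0 marks padding:
#
#     example = torch.tensor([[1, 1, 2, 2, 2, 0]])
#     mask_4d = prepare_4d_attention_mask(example, dtype=torch.float32)
#     assert mask_4d.shape == (1, 1, 6, 6)
#     # allowed (same-sequence, non-future) positions hold 0.0; all other positions hold
#     # torch.finfo(torch.float32).min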
|
||||
|
||||
|
||||
@dataclass
|
||||
class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for 4d attention mask.
|
||||
"""
|
||||
|
||||
block_diag_attn: bool = False
|
||||
attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
|
||||
compute_dtype: "torch.dtype" = torch.float32
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
features = super().__call__(features)
|
||||
if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
|
||||
features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
@@ -57,7 +117,7 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
Data collator for KTO data.
|
||||
"""
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
|
||||
target_features = []
|
||||
kl_features = []
|
||||
kto_tags = []
|
||||
|
||||
@@ -13,16 +13,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum, unique
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets
|
||||
from datasets import DatasetDict, concatenate_datasets, interleave_datasets
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from ..hparams import DataArguments
|
||||
|
||||
@@ -30,6 +29,9 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||
|
||||
|
||||
@unique
|
||||
class Role(str, Enum):
|
||||
USER = "user"
|
||||
@@ -39,31 +41,29 @@ class Role(str, Enum):
|
||||
OBSERVATION = "observation"
|
||||
|
||||
|
||||
def infer_max_len(source_len: int, target_len: int, max_len: int, reserved_label_len: int) -> Tuple[int, int]:
|
||||
max_target_len = int(max_len * (target_len / (source_len + target_len)))
|
||||
max_target_len = max(max_target_len, reserved_label_len)
|
||||
max_source_len = max_len - min(max_target_len, target_len)
|
||||
return max_source_len, max_target_len
|
||||
class DatasetModule(TypedDict):
|
||||
train_dataset: Optional[Union["Dataset", "IterableDataset"]]
|
||||
eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
|
||||
|
||||
|
||||
def merge_dataset(
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]],
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
if len(all_datasets) == 1:
|
||||
return all_datasets[0]
|
||||
elif data_args.mix_strategy == "concat":
|
||||
if data_args.streaming:
|
||||
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
||||
|
||||
return concatenate_datasets(all_datasets)
|
||||
elif data_args.mix_strategy.startswith("interleave"):
|
||||
if not data_args.streaming:
|
||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
|
||||
return interleave_datasets(
|
||||
datasets=all_datasets,
|
||||
probabilities=data_args.interleave_probs,
|
||||
seed=training_args.seed,
|
||||
seed=seed,
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||
)
|
||||
else:
|
||||
@@ -71,22 +71,17 @@ def merge_dataset(
|
||||
|
||||
|
||||
def split_dataset(
|
||||
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
|
||||
) -> Dict[str, "Dataset"]:
|
||||
if training_args.do_train:
|
||||
if data_args.val_size > 1e-6: # Split the dataset
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
val_set = dataset.take(int(data_args.val_size))
|
||||
train_set = dataset.skip(int(data_args.val_size))
|
||||
return {"train_dataset": train_set, "eval_dataset": val_set}
|
||||
else:
|
||||
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
||||
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
|
||||
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
else:
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
return {"train_dataset": dataset}
|
||||
else: # do_eval or do_predict
|
||||
return {"eval_dataset": dataset}
|
||||
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
|
||||
) -> "DatasetDict":
|
||||
r"""
|
||||
Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
|
||||
"""
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
|
||||
val_set = dataset.take(int(data_args.val_size))
|
||||
train_set = dataset.skip(int(data_args.val_size))
|
||||
return DatasetDict({"train": train_set, "validation": val_set})
|
||||
else:
|
||||
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
||||
dataset = dataset.train_test_split(test_size=val_size, seed=seed)
|
||||
return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
|
||||
|
||||
@@ -16,97 +16,10 @@ import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
|
||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||
|
||||
|
||||
DEFAULT_TOOL_PROMPT = (
|
||||
"You have access to the following tools:\n{tool_text}"
|
||||
"Use the following format if using a tool:\n"
|
||||
"```\n"
|
||||
"Action: tool name (one of [{tool_names}]).\n"
|
||||
"Action Input: the input to the tool, in a JSON format representing the kwargs "
|
||||
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
|
||||
"```\n"
|
||||
)
|
||||
|
||||
|
||||
GLM4_TOOL_PROMPT = (
|
||||
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
||||
"你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}"
|
||||
)
|
||||
|
||||
|
||||
def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
tool_names = []
|
||||
for tool in tools:
|
||||
param_text = ""
|
||||
for name, param in tool["parameters"]["properties"].items():
|
||||
required = ", required" if name in tool["parameters"].get("required", []) else ""
|
||||
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
|
||||
items = (
|
||||
", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
|
||||
)
|
||||
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
|
||||
name=name,
|
||||
type=param.get("type", ""),
|
||||
required=required,
|
||||
desc=param.get("description", ""),
|
||||
enum=enum,
|
||||
items=items,
|
||||
)
|
||||
|
||||
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
|
||||
name=tool["name"], desc=tool.get("description", ""), args=param_text
|
||||
)
|
||||
tool_names.append(tool["name"])
|
||||
|
||||
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
||||
|
||||
|
||||
def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
|
||||
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
||||
if not action_match:
|
||||
return content
|
||||
|
||||
results = []
|
||||
for match in action_match:
|
||||
tool_name = match[0].strip()
|
||||
tool_input = match[1].strip().strip('"').strip("```")
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
||||
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
|
||||
)
|
||||
|
||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
|
||||
def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
if "\n" not in content:
|
||||
return content
|
||||
|
||||
tool_name, tool_input = content.split("\n", maxsplit=1)
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
||||
from .data_utils import SLOTS
|
||||
from .tool_utils import DefaultToolUtils, GLM4ToolUtils
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -168,15 +81,12 @@ class StringFormatter(Formatter):
|
||||
@dataclass
|
||||
class FunctionFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
has_name, has_args = False, False
|
||||
for slot in filter(lambda s: isinstance(s, str), self.slots):
|
||||
if "{{name}}" in slot:
|
||||
has_name = True
|
||||
if "{{arguments}}" in slot:
|
||||
has_args = True
|
||||
|
||||
if not has_name or not has_args:
|
||||
raise ValueError("Name and arguments placeholders are required in the function formatter.")
|
||||
if self.tool_format == "default":
|
||||
self.slots = DefaultToolUtils.get_function_slots() + self.slots
|
||||
elif self.tool_format == "glm4":
|
||||
self.slots = GLM4ToolUtils.get_function_slots() + self.slots
|
||||
else:
|
||||
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
|
||||
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
@@ -210,13 +120,13 @@ class FunctionFormatter(Formatter):
|
||||
class ToolFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
if self.tool_format == "default":
|
||||
self._tool_formatter = default_tool_formatter
|
||||
self._tool_extractor = default_tool_extractor
|
||||
self._tool_formatter = DefaultToolUtils.tool_formatter
|
||||
self._tool_extractor = DefaultToolUtils.tool_extractor
|
||||
elif self.tool_format == "glm4":
|
||||
self._tool_formatter = glm4_tool_formatter
|
||||
self._tool_extractor = glm4_tool_extractor
|
||||
self._tool_formatter = GLM4ToolUtils.tool_formatter
|
||||
self._tool_extractor = GLM4ToolUtils.tool_extractor
|
||||
else:
|
||||
raise ValueError("Tool format was not found.")
|
||||
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
|
||||
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
|
||||
@@ -12,19 +12,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset, load_from_disk
|
||||
from datasets import DatasetDict, load_dataset, load_from_disk
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.constants import FILEEXT2TYPE
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import has_tokenized_data
|
||||
from .aligner import align_dataset
|
||||
from .data_utils import merge_dataset
|
||||
from .data_utils import merge_dataset, split_dataset
|
||||
from .parser import get_dataset_list
|
||||
from .preprocess import get_preprocess_and_print_func
|
||||
from .template import get_template_and_fix_tokenizer
|
||||
@@ -35,13 +35,15 @@ if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
||||
|
||||
from ..hparams import DataArguments, ModelArguments
|
||||
from .data_utils import DatasetModule
|
||||
from .parser import DatasetAttr
|
||||
from .template import Template
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def load_single_dataset(
|
||||
def _load_single_dataset(
|
||||
dataset_attr: "DatasetAttr",
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
@@ -81,41 +83,34 @@ def load_single_dataset(
|
||||
raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
|
||||
|
||||
if dataset_attr.load_from == "ms_hub":
|
||||
try:
|
||||
from modelscope import MsDataset
|
||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
||||
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||
from modelscope import MsDataset
|
||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
||||
|
||||
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
||||
dataset = MsDataset.load(
|
||||
dataset_name=data_path,
|
||||
subset_name=data_name,
|
||||
data_dir=data_dir,
|
||||
data_files=data_files,
|
||||
split=data_args.split,
|
||||
cache_dir=cache_dir,
|
||||
token=model_args.ms_hub_token,
|
||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
)
|
||||
if isinstance(dataset, MsDataset):
|
||||
dataset = dataset.to_hf_dataset()
|
||||
except ImportError:
|
||||
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
||||
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
||||
dataset = MsDataset.load(
|
||||
dataset_name=data_path,
|
||||
subset_name=data_name,
|
||||
data_dir=data_dir,
|
||||
data_files=data_files,
|
||||
split=dataset_attr.split,
|
||||
cache_dir=cache_dir,
|
||||
token=model_args.ms_hub_token,
|
||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
)
|
||||
if isinstance(dataset, MsDataset):
|
||||
dataset = dataset.to_hf_dataset()
|
||||
else:
|
||||
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
||||
kwargs = {"trust_remote_code": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
dataset = load_dataset(
|
||||
path=data_path,
|
||||
name=data_name,
|
||||
data_dir=data_dir,
|
||||
data_files=data_files,
|
||||
split=data_args.split,
|
||||
split=dataset_attr.split,
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.hf_hub_token,
|
||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
**kwargs,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||
@@ -140,6 +135,66 @@ def load_single_dataset(
|
||||
return align_dataset(dataset, dataset_attr, data_args, training_args)
|
||||
|
||||
|
||||
def _get_merged_dataset(
|
||||
dataset_names: Optional[Sequence[str]],
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||
if dataset_names is None:
|
||||
return None
|
||||
|
||||
datasets = []
|
||||
for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
|
||||
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
||||
raise ValueError("The dataset is not applicable in the current training stage.")
|
||||
|
||||
datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
|
||||
|
||||
return merge_dataset(datasets, data_args, seed=training_args.seed)
|
||||
|
||||
|
||||
def _get_preprocessed_dataset(
|
||||
dataset: Optional[Union["Dataset", "IterableDataset"]],
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
is_eval: bool = False,
|
||||
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||
if dataset is None:
|
||||
return None
|
||||
|
||||
preprocess_func, print_function = get_preprocess_and_print_func(
|
||||
data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
|
||||
)
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
|
||||
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
||||
|
||||
if training_args.should_log:
|
||||
try:
|
||||
print("eval example:" if is_eval else "training example:")
|
||||
print_function(next(iter(dataset)))
|
||||
except StopIteration:
|
||||
if stage == "pt":
|
||||
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
|
||||
else:
|
||||
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def get_dataset(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
@@ -147,8 +202,8 @@ def get_dataset(
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
|
||||
) -> "DatasetModule":
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||
|
||||
@@ -156,55 +211,66 @@ def get_dataset(
|
||||
if data_args.tokenized_path is not None:
|
||||
if has_tokenized_data(data_args.tokenized_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
dataset = load_from_disk(data_args.tokenized_path)
|
||||
dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
|
||||
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
||||
|
||||
dataset_module: Dict[str, "Dataset"] = {}
|
||||
if "train" in dataset_dict:
|
||||
dataset_module["train_dataset"] = dataset_dict["train"]
|
||||
if "validation" in dataset_dict:
|
||||
dataset_module["eval_dataset"] = dataset_dict["validation"]
|
||||
|
||||
if data_args.streaming:
|
||||
dataset = dataset.to_iterable_dataset()
|
||||
return dataset
|
||||
dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
|
||||
|
||||
return dataset_module
|
||||
|
||||
if data_args.streaming:
|
||||
raise ValueError("Turn off `streaming` when saving dataset to disk.")
|
||||
|
||||
# Load and preprocess dataset
|
||||
with training_args.main_process_first(desc="load dataset"):
|
||||
all_datasets = []
|
||||
for dataset_attr in get_dataset_list(data_args):
|
||||
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
|
||||
raise ValueError("The dataset is not applicable in the current training stage.")
|
||||
|
||||
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
|
||||
|
||||
dataset = merge_dataset(all_datasets, data_args, training_args)
|
||||
dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
|
||||
eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)
|
||||
|
||||
with training_args.main_process_first(desc="pre-process dataset"):
|
||||
preprocess_func, print_function = get_preprocess_and_print_func(
|
||||
data_args, training_args, stage, template, tokenizer, processor
|
||||
dataset = _get_preprocessed_dataset(
|
||||
dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
|
||||
)
|
||||
eval_dataset = _get_preprocessed_dataset(
|
||||
eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
|
||||
)
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
||||
desc="Running tokenizer on dataset",
|
||||
)
|
||||
|
||||
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
||||
if data_args.val_size > 1e-6:
|
||||
dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
|
||||
else:
|
||||
dataset_dict = {}
|
||||
if dataset is not None:
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
|
||||
dataset_dict["train"] = dataset
|
||||
|
||||
if eval_dataset is not None:
|
||||
if data_args.streaming:
|
||||
eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
|
||||
dataset_dict["validation"] = eval_dataset
|
||||
|
||||
dataset_dict = DatasetDict(dataset_dict)
|
||||
|
||||
if data_args.tokenized_path is not None:
|
||||
if training_args.should_save:
|
||||
dataset.save_to_disk(data_args.tokenized_path)
|
||||
dataset_dict.save_to_disk(data_args.tokenized_path)
|
||||
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
||||
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
if training_args.should_log:
|
||||
try:
|
||||
print_function(next(iter(dataset)))
|
||||
except StopIteration:
|
||||
if stage == "pt":
|
||||
raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
|
||||
else:
|
||||
raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
|
||||
dataset_module = {}
|
||||
if "train" in dataset_dict:
|
||||
dataset_module["train_dataset"] = dataset_dict["train"]
|
||||
if "validation" in dataset_dict:
|
||||
dataset_module["eval_dataset"] = dataset_dict["validation"]
|
||||
|
||||
return dataset
|
||||
return dataset_module
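A hedged sketch of how the new return value would be consumed downstream. Only the train_dataset/eval_dataset keys are confirmed by the hunk above; the training_args keyword and the trainer wiring are illustrative assumptions, not part of this diff.

    from transformers import Seq2SeqTrainer

    dataset_module = get_dataset(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args,  # assumed parameter; hidden in the elided part of the signature
        stage="sft",
        tokenizer=tokenizer,
    )
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        **dataset_module,  # expands to train_dataset=... and eval_dataset=... when present
    )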
@@ -15,47 +15,46 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence
|
||||
|
||||
from transformers.utils import cached_file
|
||||
|
||||
from ..extras.constants import DATA_CONFIG
|
||||
from ..extras.misc import use_modelscope
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..hparams import DataArguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetAttr:
|
||||
r"""
|
||||
Dataset attributes.
|
||||
"""
|
||||
|
||||
""" basic configs """
|
||||
# basic configs
|
||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||
dataset_name: str
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
ranking: bool = False
|
||||
""" extra configs """
|
||||
# extra configs
|
||||
subset: Optional[str] = None
|
||||
split: str = "train"
|
||||
folder: Optional[str] = None
|
||||
num_samples: Optional[int] = None
|
||||
""" common columns """
|
||||
# common columns
|
||||
system: Optional[str] = None
|
||||
tools: Optional[str] = None
|
||||
images: Optional[str] = None
|
||||
""" rlhf columns """
|
||||
# rlhf columns
|
||||
chosen: Optional[str] = None
|
||||
rejected: Optional[str] = None
|
||||
kto_tag: Optional[str] = None
|
||||
""" alpaca columns """
|
||||
# alpaca columns
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
history: Optional[str] = None
|
||||
""" sharegpt columns """
|
||||
# sharegpt columns
|
||||
messages: Optional[str] = "conversations"
|
||||
""" sharegpt tags """
|
||||
# sharegpt tags
|
||||
role_tag: Optional[str] = "from"
|
||||
content_tag: Optional[str] = "value"
|
||||
user_tag: Optional[str] = "human"
|
||||
@@ -71,31 +70,33 @@ class DatasetAttr:
|
||||
setattr(self, key, obj.get(key, default))
|
||||
|
||||
|
||||
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
if data_args.dataset is not None:
|
||||
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")]
|
||||
else:
|
||||
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
|
||||
r"""
|
||||
Gets the attributes of the datasets.
|
||||
"""
|
||||
if dataset_names is None:
|
||||
dataset_names = []
|
||||
|
||||
if data_args.dataset_dir == "ONLINE":
|
||||
if dataset_dir == "ONLINE":
|
||||
dataset_info = None
|
||||
else:
|
||||
if dataset_dir.startswith("REMOTE:"):
|
||||
config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
|
||||
else:
|
||||
config_path = os.path.join(dataset_dir, DATA_CONFIG)
|
||||
|
||||
try:
|
||||
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
|
||||
with open(config_path, "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception as err:
|
||||
if len(dataset_names) != 0:
|
||||
raise ValueError(
|
||||
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
||||
)
|
||||
raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
|
||||
|
||||
dataset_info = None
|
||||
|
||||
if data_args.interleave_probs is not None:
|
||||
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
|
||||
|
||||
dataset_list: List[DatasetAttr] = []
|
||||
dataset_list: List["DatasetAttr"] = []
|
||||
for name in dataset_names:
|
||||
if dataset_info is None:
|
||||
if dataset_info is None: # dataset_dir is ONLINE
|
||||
load_from = "ms_hub" if use_modelscope() else "hf_hub"
|
||||
dataset_attr = DatasetAttr(load_from, dataset_name=name)
|
||||
dataset_list.append(dataset_attr)
|
||||
@@ -120,6 +121,7 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("subset", dataset_info[name])
|
||||
dataset_attr.set_attr("split", dataset_info[name], default="train")
|
||||
dataset_attr.set_attr("folder", dataset_info[name])
|
||||
dataset_attr.set_attr("num_samples", dataset_info[name])
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from .template import Template
|
||||
@@ -35,11 +35,11 @@ if TYPE_CHECKING:
|
||||
|
||||
def get_preprocess_and_print_func(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
do_generate: bool = False,
|
||||
) -> Tuple[Callable, Callable]:
|
||||
if stage == "pt":
|
||||
preprocess_func = partial(
|
||||
@@ -48,8 +48,21 @@ def get_preprocess_and_print_func(
|
||||
data_args=data_args,
|
||||
)
|
||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
elif stage == "sft" and not do_generate:
|
||||
if data_args.packing:
|
||||
if data_args.neat_packing:
|
||||
from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence
|
||||
|
||||
def __init__(self, data, **kwargs):
|
||||
return TypedSequence.__init__(
|
||||
self,
|
||||
data,
|
||||
type=kwargs.pop("type", None),
|
||||
try_type=kwargs.pop("try_type", None),
|
||||
optimized_int_type=kwargs.pop("optimized_int_type", None),
|
||||
)
|
||||
|
||||
OptimizedTypedSequence.__init__ = __init__
|
||||
preprocess_func = partial(
|
||||
preprocess_packed_supervised_dataset,
|
||||
template=template,
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -55,12 +55,8 @@ def _encode_feedback_example(
|
||||
else:
|
||||
kl_messages = prompt + [kl_response[1]]
|
||||
|
||||
prompt_ids, response_ids = template.encode_oneturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
_, kl_response_ids = template.encode_oneturn(
|
||||
tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
|
||||
kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
|
||||
|
||||
if template.efficient_eos:
|
||||
response_ids += [tokenizer.eos_token_id]
|
||||
@@ -69,11 +65,19 @@ def _encode_feedback_example(
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
|
||||
|
||||
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
|
||||
prompt_ids = prompt_ids[:source_len]
|
||||
response_ids = response_ids[:target_len]
|
||||
kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), data_args.cutoff_len)
|
||||
kl_prompt_ids = kl_prompt_ids[:kl_source_len]
|
||||
kl_response_ids = kl_response_ids[:kl_target_len]
|
||||
|
||||
input_ids = prompt_ids + response_ids
|
||||
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
|
||||
kl_input_ids = prompt_ids + kl_response_ids
|
||||
kl_labels = [IGNORE_INDEX] * len(prompt_ids) + kl_response_ids
|
||||
labels = [IGNORE_INDEX] * source_len + response_ids
|
||||
kl_input_ids = kl_prompt_ids + kl_response_ids
|
||||
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
|
||||
|
||||
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -44,12 +44,8 @@ def _encode_pairwise_example(
|
||||
|
||||
chosen_messages = prompt + [response[0]]
|
||||
rejected_messages = prompt + [response[1]]
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(
|
||||
tokenizer, chosen_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
_, rejected_ids = template.encode_oneturn(
|
||||
tokenizer, rejected_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
|
||||
_, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
|
||||
|
||||
if template.efficient_eos:
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
@@ -59,10 +55,17 @@ def _encode_pairwise_example(
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
|
||||
|
||||
source_len, target_len = infer_seqlen(
|
||||
len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), data_args.cutoff_len
|
||||
) # consider the response is more important
|
||||
prompt_ids = prompt_ids[:source_len]
|
||||
chosen_ids = chosen_ids[:target_len]
|
||||
rejected_ids = rejected_ids[:target_len]
|
||||
|
||||
chosen_input_ids = prompt_ids + chosen_ids
|
||||
chosen_labels = [IGNORE_INDEX] * len(prompt_ids) + chosen_ids
|
||||
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
|
||||
rejected_input_ids = prompt_ids + rejected_ids
|
||||
rejected_labels = [IGNORE_INDEX] * len(prompt_ids) + rejected_ids
|
||||
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
|
||||
|
||||
return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import bisect
|
||||
from typing import TYPE_CHECKING, List, Sequence
|
||||
from typing import TYPE_CHECKING, List, Sequence, Tuple
|
||||
|
||||
from ...extras.packages import is_pillow_available
|
||||
|
||||
@@ -76,3 +76,20 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
|
||||
"""
|
||||
image_seq_length = getattr(processor, "image_seq_length")
|
||||
return [0] * image_seq_length + [1] * (input_len - image_seq_length)
|
||||
|
||||
|
||||
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
|
||||
r"""
|
||||
Computes the real sequence length after truncation by the cutoff_len.
|
||||
"""
|
||||
if target_len * 2 < cutoff_len: # truncate source
|
||||
max_target_len = cutoff_len
|
||||
elif source_len * 2 < cutoff_len: # truncate target
|
||||
max_target_len = cutoff_len - source_len
|
||||
else: # truncate both
|
||||
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
|
||||
|
||||
new_target_len = min(max_target_len, target_len)
|
||||
max_source_len = max(cutoff_len - new_target_len, 0)
|
||||
new_source_len = min(max_source_len, source_len)
|
||||
return new_source_len, new_target_len
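A few illustrative calls to infer_seqlen (numbers made up) showing how the cutoff budget is split between prompt and response:

    # cutoff_len = 10 in every case below
    infer_seqlen(source_len=20, target_len=3, cutoff_len=10)   # -> (7, 3): short response kept, prompt truncated
    infer_seqlen(source_len=2, target_len=50, cutoff_len=10)   # -> (2, 8): short prompt kept, response truncated
    infer_seqlen(source_len=30, target_len=30, cutoff_len=10)  # -> (5, 5): both truncated proportionally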
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -51,19 +51,31 @@ def _encode_supervised_example(
|
||||
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
|
||||
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
|
||||
|
||||
encoded_pairs = template.encode_multiturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
|
||||
total_length = 1 if template.efficient_eos else 0
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
|
||||
if total_length >= data_args.cutoff_len:
|
||||
break
|
||||
|
||||
source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), data_args.cutoff_len - total_length)
|
||||
source_ids = source_ids[:source_len]
|
||||
target_ids = target_ids[:target_len]
|
||||
total_length += source_len + target_len
|
||||
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
source_label = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
source_label = [IGNORE_INDEX] * source_len
|
||||
|
||||
if data_args.mask_history and turn_idx != len(encoded_pairs) - 1:
|
||||
target_label = [IGNORE_INDEX] * target_len
|
||||
else:
|
||||
target_label = target_ids
|
||||
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
labels += source_label + target_label
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
@@ -153,22 +165,30 @@ def preprocess_packed_supervised_dataset(
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len)
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids, packed_labels = [], []
|
||||
for length in knapsack:
|
||||
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
|
||||
for i, length in enumerate(knapsack):
|
||||
index = length2indexes[length].pop()
|
||||
packed_input_ids += batch_input_ids[index]
|
||||
packed_labels += batch_labels[index]
|
||||
if data_args.neat_packing:
|
||||
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
|
||||
else:
|
||||
packed_attention_masks += [1] * len(batch_input_ids[index])
|
||||
|
||||
if len(packed_input_ids) < data_args.cutoff_len:
|
||||
pad_length = data_args.cutoff_len - len(packed_input_ids)
|
||||
packed_input_ids += [tokenizer.pad_token_id] * pad_length
|
||||
packed_labels += [IGNORE_INDEX] * pad_length
|
||||
if data_args.neat_packing:
|
||||
packed_attention_masks += [0] * pad_length
|
||||
else:
|
||||
packed_attention_masks += [1] * pad_length # more efficient flash_attn
|
||||
|
||||
if len(packed_input_ids) != data_args.cutoff_len:
|
||||
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
||||
|
||||
model_inputs["input_ids"].append(packed_input_ids)
|
||||
model_inputs["attention_mask"].append([1] * data_args.cutoff_len)
|
||||
model_inputs["attention_mask"].append(packed_attention_masks)
|
||||
model_inputs["labels"].append(packed_labels)
|
||||
|
||||
return model_inputs
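A small sketch (hypothetical helper, not part of the diff) of the attention masks this loop produces: with neat_packing each packed example keeps its own index so attention can stay block-diagonal, while padding is 0; without it the whole window, padding included, is 1.

    def sketch_packed_attention_mask(example_lengths, cutoff_len, neat_packing):
        # Hypothetical illustration mirroring the packing loop above.
        mask = []
        for i, length in enumerate(example_lengths):
            mask += [i + 1] * length if neat_packing else [1] * length
        pad_length = cutoff_len - len(mask)
        mask += [0] * pad_length if neat_packing else [1] * pad_length
        return mask

    sketch_packed_attention_mask([3, 2], 8, neat_packing=True)   # [1, 1, 1, 2, 2, 0, 0, 0]
    sketch_packed_attention_mask([3, 2], 8, neat_packing=False)  # [1, 1, 1, 1, 1, 1, 1, 1]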
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ..data_utils import Role
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
|
||||
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -47,9 +47,7 @@ def _encode_unsupervised_example(
|
||||
else:
|
||||
messages = prompt + [{"role": Role.ASSISTANT.value, "content": ""}]
|
||||
|
||||
input_ids, labels = template.encode_oneturn(
|
||||
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
|
||||
)
|
||||
input_ids, labels = template.encode_oneturn(tokenizer, messages, system, tools)
|
||||
if template.efficient_eos:
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
@@ -57,6 +55,9 @@ def _encode_unsupervised_example(
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
input_ids = [image_token_id] * getattr(processor, "image_seq_length") + input_ids
|
||||
|
||||
source_len, target_len = infer_seqlen(len(input_ids), len(labels), data_args.cutoff_len)
|
||||
input_ids = input_ids[:source_len]
|
||||
labels = labels[:target_len]
|
||||
return input_ids, labels
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .data_utils import Role, infer_max_len
|
||||
from .data_utils import Role
|
||||
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
|
||||
|
||||
@@ -48,36 +48,33 @@ class Template:
|
||||
def encode_oneturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
cutoff_len: int = 1_000_000,
|
||||
reserved_label_len: int = 1,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
r"""
|
||||
Returns a single pair of token ids representing prompt and response respectively.
|
||||
"""
|
||||
encoded_pairs = self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
prompt_ids = []
|
||||
for query_ids, resp_ids in encoded_pairs[:-1]:
|
||||
prompt_ids += query_ids + resp_ids
|
||||
prompt_ids = prompt_ids + encoded_pairs[-1][0]
|
||||
answer_ids = encoded_pairs[-1][1]
|
||||
for encoded_ids in encoded_messages[:-1]:
|
||||
prompt_ids += encoded_ids
|
||||
|
||||
answer_ids = encoded_messages[-1]
|
||||
return prompt_ids, answer_ids
|
||||
|
||||
def encode_multiturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
cutoff_len: int = 1_000_000,
|
||||
reserved_label_len: int = 1,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||
"""
|
||||
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
|
||||
|
||||
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
r"""
|
||||
@@ -88,16 +85,14 @@ class Template:
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
cutoff_len: int,
|
||||
reserved_label_len: int,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
) -> List[List[int]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
Turn 0: system + query resp
|
||||
Turn t: sep + query resp
|
||||
Turn 0: prefix + system + query resp
|
||||
Turn t: sep + query resp
|
||||
"""
|
||||
system = system or self.default_system
|
||||
encoded_messages = []
|
||||
@@ -106,10 +101,9 @@ class Template:
|
||||
|
||||
if i == 0:
|
||||
elements += self.format_prefix.apply()
|
||||
|
||||
if i == 0 and (system or tools):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
elements += self.format_system.apply(content=(system + tool_text))
|
||||
if system or tools:
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
elements += self.format_system.apply(content=(system + tool_text))
|
||||
|
||||
if i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
@@ -127,11 +121,9 @@ class Template:
|
||||
|
||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||
|
||||
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
|
||||
return encoded_messages
|
||||
|
||||
def _convert_elements_to_ids(
|
||||
self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
|
||||
) -> List[int]:
|
||||
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
|
||||
r"""
|
||||
Converts elements to token ids.
|
||||
"""
|
||||
@@ -152,60 +144,32 @@ class Template:
|
||||
|
||||
return token_ids
|
||||
|
||||
def _make_pairs(
|
||||
self,
|
||||
encoded_messages: Sequence[List[int]],
|
||||
cutoff_len: int,
|
||||
reserved_label_len: int,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
encoded_pairs = []
|
||||
total_length = 0
|
||||
for i in range(0, len(encoded_messages), 2):
|
||||
if total_length >= cutoff_len:
|
||||
break
|
||||
|
||||
max_source_len, max_target_len = infer_max_len(
|
||||
source_len=len(encoded_messages[i]),
|
||||
target_len=len(encoded_messages[i + 1]),
|
||||
max_len=(cutoff_len - total_length),
|
||||
reserved_label_len=reserved_label_len,
|
||||
)
|
||||
source_ids = encoded_messages[i][:max_source_len]
|
||||
target_ids = encoded_messages[i + 1][:max_target_len]
|
||||
total_length += len(source_ids) + len(target_ids)
|
||||
encoded_pairs.append((source_ids, target_ids))
|
||||
|
||||
return encoded_pairs
|
||||
|
||||
|
||||
@dataclass
|
||||
class Llama2Template(Template):
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
messages: List[Dict[str, str]],
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
cutoff_len: int,
|
||||
reserved_label_len: int,
|
||||
) -> Sequence[Tuple[List[int], List[int]]]:
|
||||
) -> List[List[int]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
Turn 0: system + query resp
|
||||
Turn t: sep + query resp
|
||||
Turn 0: prefix + system + query resp
|
||||
Turn t: sep + query resp
|
||||
"""
|
||||
system = system or self.default_system
|
||||
encoded_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
elements = []
|
||||
|
||||
system_text = ""
|
||||
if i == 0:
|
||||
elements += self.format_prefix.apply()
|
||||
|
||||
system_text = ""
|
||||
if i == 0 and (system or tools):
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
||||
if system or tools:
|
||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
||||
|
||||
if i > 0 and i % 2 == 0:
|
||||
elements += self.format_separator.apply()
|
||||
@@ -223,7 +187,7 @@ class Llama2Template(Template):
|
||||
|
||||
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
||||
|
||||
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
|
||||
return encoded_messages
|
||||
|
||||
|
||||
TEMPLATES: Dict[str, Template] = {}
|
||||
@@ -240,7 +204,7 @@ def _register_template(
|
||||
format_separator: Optional["Formatter"] = None,
|
||||
format_prefix: Optional["Formatter"] = None,
|
||||
default_system: str = "",
|
||||
stop_words: List[str] = [],
|
||||
stop_words: Sequence[str] = [],
|
||||
image_token: str = "<image>",
|
||||
efficient_eos: bool = False,
|
||||
replace_eos: bool = False,
|
||||
@@ -275,9 +239,7 @@ def _register_template(
|
||||
template_class = Llama2Template if name.startswith("llama2") else Template
|
||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
||||
default_function_formatter = FunctionFormatter(
|
||||
slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots
|
||||
)
|
||||
default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
|
||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||
default_separator_formatter = EmptyFormatter()
|
||||
default_prefix_formatter = EmptyFormatter()
|
||||
@@ -379,6 +341,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
||||
def get_template_and_fix_tokenizer(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
name: Optional[str] = None,
|
||||
tool_format: Optional[str] = None,
|
||||
) -> Template:
|
||||
if name is None:
|
||||
template = TEMPLATES["empty"] # placeholder
|
||||
@@ -387,6 +350,12 @@ def get_template_and_fix_tokenizer(
|
||||
if template is None:
|
||||
raise ValueError("Template {} does not exist.".format(name))
|
||||
|
||||
if tool_format is not None:
|
||||
logger.info("Using tool format: {}.".format(tool_format))
|
||||
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
|
||||
template.format_tools = ToolFormatter(tool_format=tool_format)
|
||||
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
|
||||
|
||||
stop_words = template.stop_words
|
||||
if template.replace_eos:
|
||||
if not stop_words:
|
||||
@@ -501,35 +470,17 @@ _register_template(
|
||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
|
||||
format_observation=StringFormatter(
|
||||
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="glm4"),
|
||||
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="chatglm3_system",
|
||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||
format_observation=StringFormatter(
|
||||
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
||||
),
|
||||
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
||||
default_system=(
|
||||
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
|
||||
"Follow the user's instructions carefully. Respond using markdown."
|
||||
),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="chatml",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
@@ -559,6 +510,23 @@ _register_template(
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="codegeex4",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
|
||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
|
||||
format_tools=ToolFormatter(tool_format="glm4"),
|
||||
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||
default_system=(
|
||||
"你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,"
|
||||
"并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"
|
||||
),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="cohere",
|
||||
format_user=StringFormatter(
|
||||
@@ -617,22 +585,21 @@ _register_template(
|
||||
_register_template(
|
||||
name="deepseekcoder",
|
||||
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
|
||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||
format_separator=EmptyFormatter(slots=["\n<|EOT|>\n"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
default_system=(
|
||||
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
||||
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
||||
"For politically sensitive questions, security and privacy issues, "
|
||||
"and other non-computer science questions, you will refuse to answer\n"
|
||||
),
|
||||
stop_words=["<|EOT|>"],
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
|
||||
_register_template(
|
||||
name="default",
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
|
||||
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
|
||||
format_system=StringFormatter(slots=["{{content}}\n"]),
|
||||
format_separator=EmptyFormatter(slots=["\n"]),
|
||||
)
|
||||
@@ -640,7 +607,6 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="empty",
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
efficient_eos=True,
|
||||
)
|
||||
|
||||
@@ -677,7 +643,7 @@ _register_template(
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
|
||||
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
|
||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||
format_tools=ToolFormatter(tool_format="glm4"),
|
||||
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||
@@ -926,8 +892,7 @@ _register_template(
|
||||
|
||||
_register_template(
|
||||
name="zephyr",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||
default_system="You are Zephyr, a helpful assistant.",
|
||||
)
|
||||
|
||||
src/llamafactory/data/tool_utils.py (new file, 140 lines)
@@ -0,0 +1,140 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from .data_utils import SLOTS
|
||||
|
||||
|
||||
DEFAULT_TOOL_PROMPT = (
|
||||
"You have access to the following tools:\n{tool_text}"
|
||||
"Use the following format if using a tool:\n"
|
||||
"```\n"
|
||||
"Action: tool name (one of [{tool_names}])\n"
|
||||
"Action Input: the input to the tool, in a JSON format representing the kwargs "
|
||||
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```)\n"""
|
||||
"```\n"
|
||||
)
|
||||
|
||||
|
||||
GLM4_TOOL_PROMPT = (
|
||||
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
||||
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolUtils(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_function_slots() -> SLOTS: ...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
|
||||
|
||||
|
||||
class DefaultToolUtils(ToolUtils):
|
||||
@staticmethod
|
||||
def get_function_slots() -> SLOTS:
|
||||
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
|
||||
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
tool_names = []
|
||||
for tool in tools:
|
||||
param_text = ""
|
||||
for name, param in tool["parameters"]["properties"].items():
|
||||
required, enum, items = "", "", ""
|
||||
if name in tool["parameters"].get("required", []):
|
||||
required = ", required"
|
||||
|
||||
if param.get("enum", None):
|
||||
enum = ", should be one of [{}]".format(", ".join(param["enum"]))
|
||||
|
||||
if param.get("items", None):
|
||||
items = ", where each item should be {}".format(param["items"].get("type", ""))
|
||||
|
||||
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
|
||||
name=name,
|
||||
type=param.get("type", ""),
|
||||
required=required,
|
||||
desc=param.get("description", ""),
|
||||
enum=enum,
|
||||
items=items,
|
||||
)
|
||||
|
||||
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
|
||||
name=tool["name"], desc=tool.get("description", ""), args=param_text
|
||||
)
|
||||
tool_names.append(tool["name"])
|
||||
|
||||
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
||||
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
|
||||
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
||||
if not action_match:
|
||||
return content
|
||||
|
||||
results = []
|
||||
for match in action_match:
|
||||
tool_name = match[0].strip()
|
||||
tool_input = match[1].strip().strip('"').strip("```")
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
return results
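For reference, an illustrative round trip through the default extractor (the tool name and arguments below are made up):

    text = 'Action: get_weather\nAction Input: {"city": "Beijing"}\n'
    DefaultToolUtils.tool_extractor(text)
    # -> [('get_weather', '{"city": "Beijing"}')]
    DefaultToolUtils.tool_extractor("a plain text answer")
    # -> 'a plain text answer' (no Action/Action Input pair, so the content is returned as-is)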
class GLM4ToolUtils(ToolUtils):
|
||||
@staticmethod
|
||||
def get_function_slots() -> SLOTS:
|
||||
return ["{{name}}\n{{arguments}}"]
|
||||
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
||||
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
|
||||
)
|
||||
|
||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
if "\n" not in content:
|
||||
return content
|
||||
|
||||
tool_name, tool_input = content.split("\n", maxsplit=1)
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
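And the GLM-4 flavor, which expects the tool name on the first line and JSON arguments on the rest (values again made up):

    GLM4ToolUtils.tool_extractor('get_weather\n{"city": "Beijing"}')
    # -> [('get_weather', '{"city": "Beijing"}')]
    GLM4ToolUtils.tool_extractor("a plain text answer")  # no newline, returned unchanged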
@@ -37,7 +37,6 @@
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -74,8 +73,11 @@ class Evaluator:
|
||||
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
||||
|
||||
def eval(self) -> None:
|
||||
eval_task = self.eval_args.task.split("_")[0]
|
||||
eval_split = self.eval_args.task.split("_")[1]
|
||||
|
||||
mapping = cached_file(
|
||||
path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
||||
path_or_repo_id=os.path.join(self.eval_args.task_dir, eval_task),
|
||||
filename="mapping.json",
|
||||
cache_dir=self.model_args.cache_dir,
|
||||
token=self.model_args.hf_hub_token,
|
||||
@@ -88,27 +90,22 @@ class Evaluator:
|
||||
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
||||
results = {}
|
||||
for subject in pbar:
|
||||
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
||||
kwargs = {"trust_remote_code": True}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
dataset = load_dataset(
|
||||
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
||||
path=os.path.join(self.eval_args.task_dir, eval_task),
|
||||
name=subject,
|
||||
cache_dir=self.model_args.cache_dir,
|
||||
download_mode=self.eval_args.download_mode,
|
||||
token=self.model_args.hf_hub_token,
|
||||
**kwargs,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
pbar.set_postfix_str(categorys[subject]["name"])
|
||||
inputs, outputs, labels = [], [], []
|
||||
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
|
||||
for i in trange(len(dataset[eval_split]), desc="Formatting batches", position=1, leave=False):
|
||||
support_set = (
|
||||
dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
||||
)
|
||||
messages = self.eval_template.format_example(
|
||||
target_data=dataset[self.data_args.split][i],
|
||||
target_data=dataset[eval_split][i],
|
||||
support_set=support_set,
|
||||
subject_name=categorys[subject]["name"],
|
||||
)
|
||||
|
||||
@@ -78,6 +78,19 @@ TRAINING_STAGES = {
|
||||
|
||||
STAGES_USE_PAIR_DATA = {"rm", "dpo"}
|
||||
|
||||
SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
|
||||
"cohere",
|
||||
"falcon",
|
||||
"gemma",
|
||||
"gemma2",
|
||||
"llama",
|
||||
"mistral",
|
||||
"phi",
|
||||
"phi3",
|
||||
"qwen2",
|
||||
"starcoder2",
|
||||
}
|
||||
|
||||
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
||||
|
||||
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
||||
@@ -286,6 +299,17 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"CodeGeeX4-9B-Chat": {
|
||||
DownloadSource.DEFAULT: "THUDM/codegeex4-all-9b",
|
||||
DownloadSource.MODELSCOPE: "ZhipuAI/codegeex4-all-9b",
|
||||
},
|
||||
},
|
||||
template="codegeex4",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"CodeGemma-7B": {
|
||||
@@ -507,6 +531,22 @@ register_model_group(
|
||||
"Gemma-1.1-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
|
||||
},
|
||||
"Gemma-2-9B": {
|
||||
DownloadSource.DEFAULT: "google/gemma-2-9b",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
|
||||
},
|
||||
"Gemma-2-27B": {
|
||||
DownloadSource.DEFAULT: "google/gemma-2-27b",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
|
||||
},
|
||||
"Gemma-2-9B-Chat": {
|
||||
DownloadSource.DEFAULT: "google/gemma-2-9b-it",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
|
||||
},
|
||||
"Gemma-2-27B-Chat": {
|
||||
DownloadSource.DEFAULT: "google/gemma-2-27b-it",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
|
||||
},
|
||||
},
|
||||
template="gemma",
|
||||
)
|
||||
@@ -579,7 +619,26 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Jambda-v0.1": {
|
||||
"InternLM2.5-7B": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2_5-7b",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b",
|
||||
},
|
||||
"InternLM2.5-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat",
|
||||
},
|
||||
"InternLM2.5-7B-1M-Chat": {
|
||||
DownloadSource.DEFAULT: "internlm/internlm2_5-7b-chat-1m",
|
||||
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m",
|
||||
},
|
||||
},
|
||||
template="intern2",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Jamba-v0.1": {
|
||||
DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
|
||||
DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
|
||||
}
|
||||
@@ -1248,6 +1307,10 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"TeleChat-1B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/TeleChat-1B",
|
||||
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-1B",
|
||||
},
|
||||
"TeleChat-7B-Chat": {
|
||||
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
|
||||
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -23,7 +26,7 @@ import trl
|
||||
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
||||
|
||||
|
||||
VERSION = "0.8.2"
|
||||
VERSION = "0.8.3"
|
||||
|
||||
|
||||
def print_env() -> None:
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's PEFT library.
|
||||
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,15 +17,13 @@
|
||||
|
||||
import gc
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Dict, Tuple
|
||||
from typing import TYPE_CHECKING, Tuple, Union
|
||||
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
|
||||
import transformers.dynamic_module_utils
|
||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
||||
from transformers.dynamic_module_utils import get_relative_imports
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_safetensors_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_cuda_available,
|
||||
is_torch_mps_available,
|
||||
@@ -31,15 +32,9 @@ from transformers.utils import (
|
||||
)
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
|
||||
|
||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||
try:
|
||||
_is_bf16_available = is_torch_bf16_gpu_available()
|
||||
@@ -48,7 +43,7 @@ except Exception:
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from ..hparams import ModelArguments
|
||||
|
||||
@@ -78,6 +73,9 @@ class AverageMeter:
|
||||
|
||||
|
||||
def check_dependencies() -> None:
|
||||
r"""
|
||||
Checks the version of the required packages.
|
||||
"""
|
||||
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
||||
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
else:
|
||||
@@ -88,7 +86,7 @@ def check_dependencies() -> None:
|
||||
require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
|
||||
|
||||
|
||||
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
|
||||
r"""
|
||||
Returns the number of trainable parameters and number of all parameters in the model.
|
||||
"""
|
||||
@@ -99,7 +97,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
if num_params == 0 and hasattr(param, "ds_numel"):
|
||||
num_params = param.ds_numel
|
||||
|
||||
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
|
||||
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
|
||||
if param.__class__.__name__ == "Params4bit":
|
||||
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
|
||||
num_bytes = param.quant_storage.itemsize
|
||||
@@ -117,52 +115,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
return trainable_params, all_param
|
||||
|
||||
|
||||
def fix_valuehead_checkpoint(
|
||||
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
|
||||
) -> None:
|
||||
r"""
|
||||
The model is already unwrapped.
|
||||
|
||||
There are three cases:
|
||||
1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
|
||||
2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
|
||||
3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
|
||||
|
||||
We assume `stage3_gather_16bit_weights_on_model_save=true`.
|
||||
"""
|
||||
if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
|
||||
return
|
||||
|
||||
if safe_serialization:
|
||||
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
|
||||
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
|
||||
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||
else:
|
||||
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
|
||||
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
||||
|
||||
decoder_state_dict = {}
|
||||
v_head_state_dict = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("v_head."):
|
||||
v_head_state_dict[name] = param
|
||||
else:
|
||||
decoder_state_dict[name.replace("pretrained_model.", "")] = param
|
||||
|
||||
os.remove(path_to_checkpoint)
|
||||
model.pretrained_model.save_pretrained(
|
||||
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
if safe_serialization:
|
||||
save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
|
||||
|
||||
logger.info("Value head model saved at: {}".format(output_dir))
|
||||
|
||||
|
||||
def get_current_device() -> torch.device:
|
||||
def get_current_device() -> "torch.device":
|
||||
r"""
|
||||
Gets the current available device.
|
||||
"""
|
||||
@@ -201,7 +154,14 @@ def get_logits_processor() -> "LogitsProcessorList":
|
||||
return logits_processor
|
||||
|
||||
|
||||
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
||||
def has_tokenized_data(path: "os.PathLike") -> bool:
|
||||
r"""
|
||||
Checks if the path has a tokenized dataset.
|
||||
"""
|
||||
return os.path.isdir(path) and len(os.listdir(path)) > 0
|
||||
|
||||
|
||||
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
|
||||
r"""
|
||||
Infers the optimal dtype according to the model_dtype and device compatibility.
|
||||
"""
|
||||
@@ -220,11 +180,20 @@ def is_gpu_or_npu_available() -> bool:
|
||||
return is_torch_npu_available() or is_torch_cuda_available()
|
||||
|
||||
|
||||
def has_tokenized_data(path: os.PathLike) -> bool:
|
||||
r"""
|
||||
Checks if the path has a tokenized dataset.
|
||||
"""
|
||||
return os.path.isdir(path) and len(os.listdir(path)) > 0
|
||||
def numpify(inputs: Union["NDArray", "torch.Tensor"]) -> "NDArray":
|
||||
if isinstance(inputs, torch.Tensor):
|
||||
inputs = inputs.cpu()
|
||||
if inputs.dtype == torch.bfloat16: # numpy does not support bfloat16 until 1.21.4
|
||||
inputs = inputs.to(torch.float32)
|
||||
|
||||
inputs = inputs.numpy()
|
||||
|
||||
return inputs
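A hedged usage example for `numpify`, assuming it is imported from this module; bfloat16 tensors are upcast because older NumPy releases cannot represent bfloat16:

```python
import torch

logits = torch.randn(2, 4, dtype=torch.bfloat16)
array = numpify(logits)  # moved to CPU, upcast to float32, then converted to numpy
print(array.dtype, array.shape)  # float32 (2, 4)
```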
def skip_check_imports() -> None:
|
||||
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
|
||||
transformers.dynamic_module_utils.check_imports = get_relative_imports
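A brief usage note (a sketch, not part of the diff): the strict transformers import check can be kept by setting the environment variable read above before models are loaded:

```python
import os

os.environ["FORCE_CHECK_IMPORTS"] = "1"  # keep transformers' original check_imports behavior
```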
def torch_gc() -> None:
|
||||
|
||||
@@ -81,3 +81,8 @@ def is_vllm_available():
|
||||
@lru_cache
|
||||
def is_vllm_version_greater_than_0_5():
|
||||
return _get_package_version("vllm") >= version.parse("0.5.0")
|
||||
|
||||
|
||||
@lru_cache
|
||||
def is_vllm_version_greater_than_0_5_1():
|
||||
return _get_package_version("vllm") >= version.parse("0.5.1")
|
||||
|
||||
@@ -31,27 +31,27 @@ class DataArguments:
|
||||
)
|
||||
dataset: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
|
||||
metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
|
||||
)
|
||||
eval_dataset: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
|
||||
)
|
||||
dataset_dir: str = field(
|
||||
default="data",
|
||||
metadata={"help": "Path to the folder containing the datasets."},
|
||||
)
|
||||
split: str = field(
|
||||
default="train",
|
||||
metadata={"help": "Which dataset split to use for training and evaluation."},
|
||||
)
|
||||
cutoff_len: int = field(
|
||||
default=1024,
|
||||
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
|
||||
)
|
||||
reserved_label_len: int = field(
|
||||
default=1,
|
||||
metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
|
||||
)
|
||||
train_on_prompt: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to disable the mask on the prompt or not."},
|
||||
metadata={"help": "Whether or not to disable the mask on the prompt."},
|
||||
)
|
||||
mask_history: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to mask the history and train on the last turn only."},
|
||||
)
|
||||
streaming: bool = field(
|
||||
default=False,
|
||||
@@ -87,9 +87,7 @@ class DataArguments:
|
||||
)
|
||||
ignore_pad_token_for_loss: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
|
||||
},
|
||||
metadata={"help": "Whether or not to ignore the tokens corresponding to the pad label in loss computation."},
|
||||
)
|
||||
val_size: float = field(
|
||||
default=0.0,
|
||||
@@ -97,9 +95,15 @@ class DataArguments:
|
||||
)
|
||||
packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
|
||||
},
|
||||
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
|
||||
)
|
||||
neat_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable sequence packing without cross-attention."},
|
||||
)
|
||||
tool_format: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Tool format to use for constructing function calling examples."},
|
||||
)
|
||||
tokenized_path: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -107,8 +111,30 @@ class DataArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.reserved_label_len >= self.cutoff_len:
|
||||
raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
|
||||
def split_arg(arg):
|
||||
if isinstance(arg, str):
|
||||
return [item.strip() for item in arg.split(",")]
|
||||
return arg
|
||||
|
||||
self.dataset = split_arg(self.dataset)
|
||||
self.eval_dataset = split_arg(self.eval_dataset)
|
||||
|
||||
if self.dataset is None and self.val_size > 1e-6:
|
||||
raise ValueError("Cannot specify `val_size` if `dataset` is None.")
|
||||
|
||||
if self.eval_dataset is not None and self.val_size > 1e-6:
|
||||
raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
|
||||
|
||||
if self.interleave_probs is not None:
|
||||
if self.mix_strategy == "concat":
|
||||
raise ValueError("`interleave_probs` is only valid for interleaved mixing.")
|
||||
|
||||
self.interleave_probs = list(map(float, split_arg(self.interleave_probs)))
|
||||
if self.dataset is not None and len(self.dataset) != len(self.interleave_probs):
|
||||
raise ValueError("The length of dataset and interleave probs should be identical.")
|
||||
|
||||
if self.eval_dataset is not None and len(self.eval_dataset) != len(self.interleave_probs):
|
||||
raise ValueError("The length of eval dataset and interleave probs should be identical.")
|
||||
|
||||
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
|
||||
raise ValueError("Streaming mode should have an integer val size.")
@@ -113,7 +113,7 @@ class LoraArguments:
|
||||
metadata={"help": "Whether or not to initialize a PiSSA adapter."},
|
||||
)
|
||||
pissa_iter: int = field(
|
||||
default=4,
|
||||
default=16,
|
||||
metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
|
||||
)
|
||||
pissa_convert: bool = field(
|
||||
@@ -334,6 +334,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
|
||||
)
|
||||
compute_accuracy: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to compute the token-level accuracy at evaluation."},
|
||||
)
|
||||
plot_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to save the training loss curves."},
|
||||
@@ -376,14 +380,21 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
if self.use_galore and self.use_badam:
|
||||
raise ValueError("Cannot use GaLore with BAdam together.")
|
||||
|
||||
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
|
||||
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
|
||||
|
||||
if self.pissa_convert and self.finetuning_type != "lora":
|
||||
raise ValueError("`pissa_convert` is only valid for LoRA training.")
|
||||
|
||||
if self.pissa_convert and (self.stage in ["rm", "ppo", "kto"] or self.use_ref_model):
|
||||
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
|
||||
raise ValueError("Cannot use PiSSA for current training stage.")
|
||||
|
||||
if self.train_mm_proj_only and self.finetuning_type != "full":
|
||||
raise ValueError("`train_mm_proj_only` is only valid for full training.")
|
||||
|
||||
if self.finetuning_type != "lora":
|
||||
if self.loraplus_lr_ratio is not None:
|
||||
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
|
||||
|
||||
if self.use_rslora:
|
||||
raise ValueError("`use_rslora` is only valid for LoRA training.")
|
||||
|
||||
if self.use_dora:
|
||||
raise ValueError("`use_dora` is only valid for LoRA training.")
|
||||
|
||||
if self.pissa_init:
|
||||
raise ValueError("`pissa_init` is only valid for LoRA training.")
|
||||
|
||||
@@ -77,6 +77,10 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use memory-efficient model loading."},
|
||||
)
|
||||
quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field(
|
||||
default="bitsandbytes",
|
||||
metadata={"help": "Quantization method to use for on-the-fly quantization."},
|
||||
)
|
||||
quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the model using bitsandbytes."},
|
||||
@@ -97,7 +101,7 @@ class ModelArguments:
|
||||
default=None,
|
||||
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
|
||||
)
|
||||
flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
|
||||
flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
|
||||
default="auto",
|
||||
metadata={"help": "Enable FlashAttention for faster training and inference."},
|
||||
)
|
||||
@@ -222,6 +226,7 @@ class ModelArguments:
|
||||
self.compute_dtype: Optional["torch.dtype"] = None
|
||||
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
|
||||
self.model_max_length: Optional[int] = None
|
||||
self.block_diag_attn: bool = False
|
||||
|
||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||
@@ -235,9 +240,6 @@ class ModelArguments:
|
||||
if self.new_special_tokens is not None: # support multiple special tokens
|
||||
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
|
||||
|
||||
assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
|
||||
|
||||
if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
|
||||
raise ValueError("Quantization dataset is necessary for exporting.")
|
||||
|
||||
@@ -252,4 +254,5 @@ class ModelArguments:
|
||||
new_arg.compute_dtype = old_arg.compute_dtype
|
||||
new_arg.device_map = old_arg.device_map
|
||||
new_arg.model_max_length = old_arg.model_max_length
|
||||
new_arg.block_diag_attn = old_arg.block_diag_attn
|
||||
return new_arg
|
||||
|
||||
@@ -79,13 +79,14 @@ def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
|
||||
def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
|
||||
def _verify_model_args(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> None:
|
||||
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Adapter is only valid for the LoRA method.")
|
||||
|
||||
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
|
||||
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
@@ -102,6 +103,10 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
|
||||
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
|
||||
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||
|
||||
if data_args.template == "yi" and model_args.use_fast_tokenizer:
|
||||
logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
|
||||
model_args.use_fast_tokenizer = False
|
||||
|
||||
|
||||
def _check_extra_dependencies(
|
||||
model_args: "ModelArguments",
|
||||
@@ -121,7 +126,7 @@ def _check_extra_dependencies(
|
||||
require_version("galore_torch", "To fix: pip install galore_torch")
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
require_version("badam", "To fix: pip install badam")
|
||||
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
|
||||
|
||||
if finetuning_args.plot_loss:
|
||||
require_version("matplotlib", "To fix: pip install matplotlib")
|
||||
@@ -161,6 +166,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
||||
|
||||
if finetuning_args.stage != "sft" and data_args.neat_packing:
|
||||
raise ValueError("`neat_packing` cannot be set as True except SFT.")
|
||||
|
||||
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
||||
|
||||
@@ -186,21 +194,38 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
|
||||
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
|
||||
|
||||
if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED:
|
||||
raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.")
|
||||
|
||||
if training_args.max_steps == -1 and data_args.streaming:
|
||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||
|
||||
if training_args.do_train and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||
if training_args.do_train and data_args.dataset is None:
|
||||
raise ValueError("Please specify dataset for training.")
|
||||
|
||||
if (training_args.do_eval or training_args.do_predict) and (
|
||||
data_args.eval_dataset is None and data_args.val_size < 1e-6
|
||||
):
|
||||
raise ValueError("Please specify dataset for evaluation.")
|
||||
|
||||
if training_args.predict_with_generate and data_args.eval_dataset is None:
|
||||
raise ValueError("Cannot use `predict_with_generate` if `eval_dataset` is None.")
|
||||
|
||||
if training_args.predict_with_generate and finetuning_args.compute_accuracy:
|
||||
raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")
|
||||
|
||||
if training_args.do_train and model_args.quantization_device_map == "auto":
|
||||
raise ValueError("Cannot use device map for quantized models in training.")
|
||||
|
||||
if finetuning_args.pissa_init and is_deepspeed_zero3_enabled():
|
||||
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA in DeepSpeed ZeRO-3.")
|
||||
|
||||
if finetuning_args.pure_bf16:
|
||||
if not is_torch_bf16_gpu_available():
|
||||
raise ValueError("This device does not support `pure_bf16`.")
|
||||
|
||||
if training_args.fp16 or training_args.bf16:
|
||||
raise ValueError("Turn off mixed precision training when using `pure_bf16`.")
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.")
|
||||
|
||||
if (
|
||||
finetuning_args.use_galore
|
||||
@@ -209,15 +234,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
):
|
||||
raise ValueError("Distributed training does not support layer-wise GaLore.")
|
||||
|
||||
if (
|
||||
finetuning_args.use_badam
|
||||
and finetuning_args.badam_mode == "layer"
|
||||
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
):
|
||||
raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
|
||||
if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
if finetuning_args.badam_mode == "ratio":
|
||||
raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
|
||||
elif not is_deepspeed_zero3_enabled():
|
||||
raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.")
|
||||
|
||||
if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
|
||||
raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
|
||||
if finetuning_args.use_galore and training_args.deepspeed is not None:
|
||||
raise ValueError("GaLore is incompatible with DeepSpeed yet.")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||
@@ -225,7 +249,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
if model_args.visual_inputs and data_args.packing:
|
||||
raise ValueError("Cannot use packing in MLLM fine-tuning.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
|
||||
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
|
||||
|
||||
if data_args.neat_packing and not data_args.packing:
|
||||
logger.warning("`neat_packing` requires `packing` is True. Change `packing` to True.")
|
||||
data_args.packing = True
|
||||
|
||||
_verify_model_args(model_args, data_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||
|
||||
if (
|
||||
@@ -306,6 +337,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
|
||||
model_args.device_map = {"": get_current_device()}
|
||||
model_args.model_max_length = data_args.cutoff_len
|
||||
model_args.block_diag_attn = data_args.neat_packing
|
||||
data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
|
||||
|
||||
# Log on each process the small summary
|
||||
@@ -348,7 +380,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
|
||||
if finetuning_args.stage == "rm" and model_args.visual_inputs:
|
||||
raise ValueError("Reward server does not support MLLM yet. Stay tuned.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_verify_model_args(model_args, data_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args)
|
||||
|
||||
if model_args.export_dir is not None and model_args.export_device == "cpu":
|
||||
@@ -371,7 +403,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
|
||||
if model_args.infer_backend == "vllm":
|
||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||
|
||||
_verify_model_args(model_args, finetuning_args)
|
||||
_verify_model_args(model_args, data_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args)
|
||||
|
||||
model_args.device_map = "auto"
|
||||
|
||||
@@ -14,10 +14,12 @@
|
||||
|
||||
from .loader import load_config, load_model, load_tokenizer
|
||||
from .model_utils.misc import find_all_linear_modules
|
||||
from .model_utils.quantization import QuantizationMethod
|
||||
from .model_utils.valuehead import load_valuehead_params
|
||||
|
||||
|
||||
__all__ = [
|
||||
"QuantizationMethod",
|
||||
"load_config",
|
||||
"load_model",
|
||||
"load_tokenizer",
|
||||
|
||||
@@ -289,16 +289,15 @@ def init_adapter(
|
||||
raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
|
||||
|
||||
# cast trainable parameters to float32 if:
|
||||
# 1. is_trainable and quantization_bit is not None (qlora)
|
||||
# 2. is_trainable and not deepspeed zero3 and not fsdp (zero3 or fsdp already in float32)
|
||||
# 3. is_trainable and not pure_bf16 and not badam
|
||||
# 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
|
||||
# 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
|
||||
cast_trainable_params_to_fp32 = False
|
||||
if not is_trainable:
|
||||
cast_trainable_params_to_fp32 = False
|
||||
elif model_args.quantization_bit is None and (
|
||||
is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam
|
||||
):
|
||||
logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
|
||||
cast_trainable_params_to_fp32 = False
|
||||
pass
|
||||
elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
|
||||
logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
|
||||
elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
|
||||
logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
|
||||
else:
|
||||
logger.info("Upcasting trainable params to float32.")
|
||||
cast_trainable_params_to_fp32 = True
|
||||
|
||||
@@ -14,11 +14,12 @@
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import count_parameters, try_download_model_from_ms
|
||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_ms
|
||||
from .adapter import init_adapter
|
||||
from .model_utils.misc import register_autoclass
|
||||
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
|
||||
@@ -47,6 +48,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
||||
|
||||
Note: including inplace operation of model_args.
|
||||
"""
|
||||
skip_check_imports()
|
||||
model_args.model_name_or_path = try_download_model_from_ms(model_args)
|
||||
return {
|
||||
"trust_remote_code": True,
|
||||
@@ -175,17 +177,21 @@ def load_model(
|
||||
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False)
|
||||
for param in model.parameters():
|
||||
if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
|
||||
param.data = param.data.to(model_args.compute_dtype)
|
||||
|
||||
model.eval()
|
||||
else:
|
||||
model.train()
|
||||
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
if is_trainable:
|
||||
param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
param_stats = "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param
|
||||
)
|
||||
else:
|
||||
param_stats = "all params: {:d}".format(all_param)
|
||||
param_stats = "all params: {:,}".format(all_param)
|
||||
|
||||
logger.info(param_stats)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
@@ -28,11 +29,26 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
|
||||
def configure_attn_implementation(
|
||||
config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
|
||||
) -> None:
|
||||
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
|
||||
if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
|
||||
if is_flash_attn_2_available():
|
||||
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
|
||||
require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0")
|
||||
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
|
||||
model_args.flash_attn = "fa2"
|
||||
else:
|
||||
logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
|
||||
model_args.flash_attn = "disabled"
|
||||
elif model_args.flash_attn == "sdpa":
|
||||
logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
|
||||
|
||||
if model_args.flash_attn == "auto":
|
||||
return
|
||||
|
||||
elif model_args.flash_attn == "off":
|
||||
elif model_args.flash_attn == "disabled":
|
||||
requested_attn_implementation = "eager"
|
||||
|
||||
elif model_args.flash_attn == "sdpa":
|
||||
|
||||
@@ -78,9 +78,7 @@ def _fp32_forward_post_hook(
|
||||
return output.to(torch.float32)
|
||||
|
||||
|
||||
def prepare_model_for_training(
|
||||
model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
|
||||
) -> None:
|
||||
def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
|
||||
r"""
|
||||
Includes:
|
||||
(1) cast the layernorm in fp32
|
||||
@@ -104,8 +102,8 @@ def prepare_model_for_training(
|
||||
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
|
||||
logger.info("Gradient checkpointing enabled.")
|
||||
|
||||
if hasattr(model, output_layer_name) and model_args.upcast_lmhead_output:
|
||||
logger.info("Upcasting lm_head outputs in float32.")
|
||||
output_layer = getattr(model, output_layer_name)
|
||||
if model_args.upcast_lmhead_output:
|
||||
output_layer = model.get_output_embeddings()
|
||||
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
|
||||
logger.info("Upcasting lm_head outputs in float32.")
|
||||
output_layer.register_forward_hook(_fp32_forward_post_hook)
|
||||
|
||||
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
transformers_logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Modified from:
|
||||
@@ -71,8 +71,6 @@ def llama_attention_forward(
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
@@ -85,7 +83,7 @@ def llama_attention_forward(
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
num_groups = q_len // groupsz
|
||||
|
||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||
def shift(state: "torch.Tensor") -> "torch.Tensor":
|
||||
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
|
||||
state = torch.cat(
|
||||
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
|
||||
@@ -156,8 +154,6 @@ def llama_flash_attention_2_forward(
|
||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
@@ -181,7 +177,7 @@ def llama_flash_attention_2_forward(
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once("The input hidden states seems to be silently casted in float32.")
|
||||
transformers_logger.warning_once("The input hidden states seems to be silently casted in float32.")
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
@@ -191,7 +187,7 @@ def llama_flash_attention_2_forward(
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
num_groups = q_len // groupsz
|
||||
|
||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||
def shift(state: "torch.Tensor") -> "torch.Tensor":
|
||||
state = torch.cat(
|
||||
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
|
||||
dim=2,
|
||||
@@ -202,7 +198,7 @@ def llama_flash_attention_2_forward(
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
|
||||
|
||||
attn_output: torch.Tensor = self._flash_attention_forward(
|
||||
attn_output: "torch.Tensor" = self._flash_attention_forward(
|
||||
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
|
||||
)
|
||||
|
||||
@@ -238,7 +234,9 @@ def llama_sdpa_attention_forward(
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
if output_attentions:
|
||||
logger.warning_once("SDPA does not support `output_attentions=True`. Falling back to the vanilla attention")
|
||||
transformers_logger.warning_once(
|
||||
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
|
||||
)
|
||||
return llama_attention_forward(
|
||||
self,
|
||||
hidden_states=hidden_states,
|
||||
@@ -275,7 +273,7 @@ def llama_sdpa_attention_forward(
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
num_groups = q_len // groupsz
|
||||
|
||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||
def shift(state: "torch.Tensor") -> "torch.Tensor":
|
||||
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
|
||||
state = torch.cat(
|
||||
(state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)),
|
||||
@@ -291,18 +289,19 @@ def llama_sdpa_attention_forward(
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
|
||||
|
||||
if query_states.device.type == "cuda" and causal_mask is not None:
|
||||
if query_states.device.type == "cuda" and causal_mask is not None: # avoid pytorch bug
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
|
||||
is_causal = True if causal_mask is None and q_len > 1 else False
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=causal_mask,
|
||||
dropout_p=self.attention_dropout if self.training else 0.0,
|
||||
is_causal=causal_mask is None and q_len > 1,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
@@ -323,7 +322,7 @@ def llama_sdpa_attention_forward(
|
||||
|
||||
|
||||
def _apply_llama_patch() -> None:
|
||||
require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
|
||||
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
|
||||
LlamaAttention.forward = llama_attention_forward
|
||||
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
|
||||
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
src/llamafactory/model/model_utils/packing.py (new file, +149 lines)
@@ -0,0 +1,149 @@
|
||||
# Copyright 2024 Musab Gultekin and the LlamaFactory team.
|
||||
#
|
||||
# This code is based on the Musab Gultekin's functionary library.
|
||||
# https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 Musab Gultekin
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers.models
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from ...hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
|
||||
r"""
|
||||
Gets the sequence lengths in the current batch.
|
||||
|
||||
e.g.
|
||||
```python
|
||||
# input
|
||||
[
|
||||
[1, 1, 2, 2, 2, 0],
|
||||
[1, 2, 2, 3, 3, 3],
|
||||
]
|
||||
# output
|
||||
[2, 3, 1, 2, 3]
|
||||
```
|
||||
"""
|
||||
bsz = attention_mask.size(0)
|
||||
dtype, device = attention_mask.dtype, attention_mask.device
|
||||
max_num = torch.max(attention_mask).item()
|
||||
counts: "torch.Tensor" = torch.zeros((bsz, max_num), dtype=dtype, device=device)
|
||||
for i in range(max_num):
|
||||
counts[:, i] = torch.sum(attention_mask == (i + 1), dim=-1)
|
||||
|
||||
counts = counts.flatten()
|
||||
seqlens = counts[counts.nonzero().squeeze(dim=-1)]
|
||||
return seqlens
|
||||
|
||||
|
||||
def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:
|
||||
r"""
|
||||
Prepares the indices and seqlens for flash attn varlen function.
|
||||
|
||||
Returns:
|
||||
indices: indices of non-masked tokens from the flattened sequence.
|
||||
cu_seqlens: the cumulative sequence lengths in the current batch, always starts from 0.
|
||||
max_seqlen_in_batch: the largest seqlen in the current batch.
|
||||
|
||||
e.g.
|
||||
```python
|
||||
# input
|
||||
[
|
||||
[1, 1, 2, 2, 2, 0],
|
||||
[1, 2, 2, 3, 3, 3],
|
||||
]
|
||||
# output
|
||||
[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
|
||||
[0, 2, 5, 6, 8, 11]
|
||||
3
|
||||
```
|
||||
"""
|
||||
seqlens_in_batch = get_seqlens_in_batch(attention_mask)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
return indices, cu_seqlens, max_seqlen_in_batch
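A hedged usage sketch that reproduces the docstring example with the two helpers above (assumes they are importable from this module):

```python
import torch

attention_mask = torch.tensor(
    [
        [1, 1, 2, 2, 2, 0],
        [1, 2, 2, 3, 3, 3],
    ]
)
indices, cu_seqlens, max_seqlen = get_unpad_data(attention_mask)
print(indices.tolist())     # [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11]
print(cu_seqlens.tolist())  # [0, 2, 5, 6, 8, 11]
print(max_seqlen)           # 3
```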
def _patch_for_block_diag_attn(model_type: str) -> None:
|
||||
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
|
||||
if model_type == "cohere":
|
||||
transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
|
||||
elif model_type == "falcon":
|
||||
transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
|
||||
elif model_type == "gemma":
|
||||
transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
|
||||
elif model_type == "gemma2":
|
||||
transformers.models.gemma2.modeling_gemma2._get_unpad_data = get_unpad_data
|
||||
elif model_type == "llama":
|
||||
transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data
|
||||
elif model_type == "mistral":
|
||||
transformers.models.mistral.modeling_mistral._get_unpad_data = get_unpad_data
|
||||
elif model_type == "phi":
|
||||
transformers.models.phi.modeling_phi._get_unpad_data = get_unpad_data
|
||||
elif model_type == "phi3":
|
||||
transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data
|
||||
elif model_type == "qwen2":
|
||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data
|
||||
elif model_type == "starcoder2":
|
||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = get_unpad_data
|
||||
|
||||
|
||||
def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.block_diag_attn:
|
||||
return
|
||||
|
||||
model_type = getattr(config, "model_type", None)
|
||||
if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
|
||||
_patch_for_block_diag_attn(model_type)
|
||||
logger.info("Using block diagonal attention for sequence packing without cross-attention.")
|
||||
else:
|
||||
raise ValueError("Current model does not support block diagonal attention.")
|
||||
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import BitsAndBytesConfig, GPTQConfig
|
||||
from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.modeling_utils import is_fsdp_enabled
|
||||
from transformers.utils.versions import require_version
|
||||
@@ -57,9 +57,9 @@ class QuantizationMethod(str, Enum):
|
||||
HQQ = "hqq"
|
||||
|
||||
|
||||
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
|
||||
def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]:
|
||||
r"""
|
||||
TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
|
||||
Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization.
|
||||
"""
|
||||
if os.path.isfile(model_args.export_quantization_dataset):
|
||||
data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
|
||||
@@ -68,20 +68,32 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
|
||||
data_path = model_args.export_quantization_dataset
|
||||
data_files = None
|
||||
|
||||
dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
|
||||
maxlen = model_args.export_quantization_maxlen
|
||||
dataset = load_dataset(
|
||||
path=data_path,
|
||||
data_files=data_files,
|
||||
split="train",
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.hf_hub_token,
|
||||
)
|
||||
|
||||
samples = []
|
||||
maxlen = model_args.export_quantization_maxlen
|
||||
for _ in range(model_args.export_quantization_nsamples):
|
||||
n_try = 0
|
||||
while True:
|
||||
if n_try > 100:
|
||||
raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.")
|
||||
|
||||
sample_idx = random.randint(0, len(dataset) - 1)
|
||||
sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
|
||||
if sample["input_ids"].size(1) >= maxlen:
|
||||
sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
|
||||
n_try += 1
|
||||
if sample["input_ids"].size(1) > maxlen:
|
||||
break # TODO: fix large maxlen
|
||||
|
||||
word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
|
||||
input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
|
||||
samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
|
||||
attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen]
|
||||
samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()})
|
||||
|
||||
return samples
|
||||
|
||||
@@ -93,11 +105,14 @@ def configure_quantization(
|
||||
init_kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
r"""
|
||||
Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
|
||||
Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer)
|
||||
"""
|
||||
if getattr(config, "quantization_config", None): # ptq
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with PTQ-quantized models.")
|
||||
if model_args.quantization_bit is not None:
|
||||
logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
|
||||
|
||||
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
|
||||
|
||||
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
quant_method = quantization_config.get("quant_method", "")
|
||||
@@ -111,7 +126,6 @@ def configure_quantization(
|
||||
require_version("autoawq", "To fix: pip install autoawq")
|
||||
|
||||
if quant_method == QuantizationMethod.AQLM:
|
||||
require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
|
||||
require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
|
||||
quantization_config["bits"] = 2
|
||||
|
||||
@@ -119,46 +133,72 @@ def configure_quantization(
|
||||
logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper()))
|
||||
|
||||
elif model_args.export_quantization_bit is not None: # auto-gptq
|
||||
require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
|
||||
if model_args.export_quantization_bit not in [8, 4, 3, 2]:
|
||||
raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.")
|
||||
|
||||
require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0")
|
||||
require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
|
||||
from accelerate.utils import get_max_memory
|
||||
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
raise ValueError("ChatGLM model is not supported.")
|
||||
raise ValueError("ChatGLM model is not supported yet.")
|
||||
|
||||
init_kwargs["quantization_config"] = GPTQConfig(
|
||||
bits=model_args.export_quantization_bit,
|
||||
tokenizer=tokenizer,
|
||||
dataset=_get_quantization_dataset(tokenizer, model_args),
|
||||
)
|
||||
init_kwargs["device_map"] = "auto"
|
||||
init_kwargs["max_memory"] = get_max_memory()
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
|
||||
logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit))
|
||||
|
||||
elif model_args.quantization_bit is not None: # bnb
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
elif model_args.quantization_bit is not None: # on-the-fly
|
||||
if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
elif model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||
bnb_4bit_quant_type=model_args.quantization_type,
|
||||
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
|
||||
)
|
||||
else:
|
||||
raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
|
||||
|
||||
elif model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.compute_dtype,
|
||||
bnb_4bit_use_double_quant=model_args.double_quantization,
|
||||
bnb_4bit_quant_type=model_args.quantization_type,
|
||||
bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora
|
||||
)
|
||||
# Do not assign device map if:
|
||||
# 1. deepspeed zero3 or fsdp (train)
|
||||
# 2. auto quantization device map (inference)
|
||||
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
|
||||
if model_args.quantization_bit != 4:
|
||||
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
|
||||
|
||||
# assign device map if:
|
||||
# 1. not deepspeed zero3 and not fsdp
|
||||
# 2. not auto quantization device map
|
||||
if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
|
||||
if model_args.quantization_bit != 4:
|
||||
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
|
||||
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
|
||||
else:
|
||||
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
|
||||
|
||||
require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
|
||||
else:
|
||||
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
|
||||
logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit))
|
||||
elif model_args.quantization_method == QuantizationMethod.HQQ.value:
|
||||
if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
|
||||
raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
|
||||
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
|
||||
raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
|
||||
|
||||
require_version("hqq", "To fix: pip install hqq")
|
||||
init_kwargs["quantization_config"] = HqqConfig(
|
||||
nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
|
||||
) # use ATEN kernel (axis=0) for performance
|
||||
logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit))
|
||||
elif model_args.quantization_method == QuantizationMethod.EETQ.value:
|
||||
if model_args.quantization_bit != 8:
|
||||
raise ValueError("EETQ only accepts 8-bit quantization.")
|
||||
|
||||
if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
|
||||
raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.")
|
||||
|
||||
require_version("eetq", "To fix: pip install eetq")
|
||||
init_kwargs["quantization_config"] = EetqConfig()
|
||||
logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit))
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Tuple
|
||||
import torch
|
||||
import transformers.models
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.utils import logging
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
|
||||
@@ -31,6 +32,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
transformers_logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
|
||||
@@ -61,7 +63,7 @@ class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
|
||||
else:
|
||||
target_dtype = self.linear_1.weight.dtype
|
||||
|
||||
logger.warning_once("The hidden states seems to be silently casted in float32.")
|
||||
transformers_logger.warning_once("The hidden states seems to be silently casted in float32.")
|
||||
hidden_states = hidden_states.to(target_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -21,6 +21,7 @@ from peft import PeftModel
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.modeling_utils import is_fsdp_enabled
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import infer_optim_dtype
|
||||
@@ -29,6 +30,7 @@ from .model_utils.checkpointing import prepare_model_for_training
|
||||
from .model_utils.embedding import resize_embedding_layer
|
||||
from .model_utils.longlora import configure_longlora
|
||||
from .model_utils.moe import add_z3_leaf_module, configure_moe
|
||||
from .model_utils.packing import configure_packing
|
||||
from .model_utils.quantization import configure_quantization
|
||||
from .model_utils.rope import configure_rope
|
||||
from .model_utils.valuehead import prepare_valuehead_model
|
||||
@@ -58,21 +60,22 @@ def patch_config(
|
||||
is_trainable: bool,
|
||||
) -> None:
|
||||
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
|
||||
if model_args.infer_dtype == "auto":
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
else:
|
||||
if model_args.infer_dtype != "auto" and not is_trainable:
|
||||
model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
|
||||
else:
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
|
||||
if is_torch_npu_available():
|
||||
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]
|
||||
torch.npu.set_compile_mode(jit_compile=use_jit_compile)
|
||||
|
||||
configure_attn_implementation(config, model_args)
|
||||
configure_attn_implementation(config, model_args, is_trainable)
|
||||
configure_rope(config, model_args, is_trainable)
|
||||
configure_longlora(config, model_args, is_trainable)
|
||||
configure_quantization(config, tokenizer, model_args, init_kwargs)
|
||||
configure_moe(config, model_args, is_trainable)
|
||||
configure_visual_model(config)
|
||||
configure_packing(config, model_args, is_trainable)
|
||||
|
||||
if model_args.use_cache and not is_trainable:
|
||||
setattr(config, "use_cache", True)
|
||||
@@ -86,13 +89,16 @@ def patch_config(
|
||||
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2":
|
||||
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn
|
||||
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
|
||||
|
||||
# deepspeed zero3 is not compatible with low_cpu_mem_usage
|
||||
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled())
|
||||
|
||||
# cast data type of the model if:
|
||||
# 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
|
||||
# 2. fsdp + qlora
|
||||
if model_args.quantization_bit is not None or (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()):
|
||||
# 2. quantization_bit is not None (qlora)
|
||||
if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
|
||||
init_kwargs["torch_dtype"] = model_args.compute_dtype
|
||||
|
||||
if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True
|
||||
@@ -152,6 +158,10 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
|
||||
if isinstance(self.pretrained_model, PreTrainedModel):
|
||||
return self.pretrained_model.get_input_embeddings()
|
||||
|
||||
def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
|
||||
if isinstance(self.pretrained_model, PreTrainedModel):
|
||||
return self.pretrained_model.get_output_embeddings()
|
||||
|
||||
def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
|
||||
if isinstance(self.pretrained_model, PeftModel):
|
||||
self.pretrained_model.create_or_update_model_card(output_dir)
|
||||
@@ -160,4 +170,5 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
|
||||
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
|
||||
setattr(model, "tie_weights", MethodType(tie_weights, model))
|
||||
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
|
||||
setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model))
|
||||
setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))
|
||||
|
||||
@@ -22,22 +22,78 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import TrainerCallback
|
||||
from peft import PeftModel
|
||||
from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_safetensors_available,
|
||||
)
|
||||
|
||||
from .constants import TRAINER_LOG
|
||||
from .logging import LoggerHandler, get_logger
|
||||
from .misc import fix_valuehead_checkpoint
|
||||
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.logging import LoggerHandler, get_logger
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerControl, TrainerState, TrainingArguments
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def fix_valuehead_checkpoint(
|
||||
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
|
||||
) -> None:
|
||||
r"""
|
||||
The model is already unwrapped.
|
||||
|
||||
There are three cases:
|
||||
1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
|
||||
2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
|
||||
3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
|
||||
|
||||
We assume `stage3_gather_16bit_weights_on_model_save=true`.
|
||||
"""
|
||||
if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
|
||||
return
|
||||
|
||||
if safe_serialization:
|
||||
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
|
||||
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
|
||||
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
|
||||
else:
|
||||
path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
|
||||
state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
|
||||
|
||||
decoder_state_dict = {}
|
||||
v_head_state_dict = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("v_head."):
|
||||
v_head_state_dict[name] = param
|
||||
else:
|
||||
decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param
|
||||
|
||||
model.pretrained_model.save_pretrained(
|
||||
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
|
||||
)
|
||||
|
||||
if safe_serialization:
|
||||
save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
|
||||
|
||||
os.remove(path_to_checkpoint)
|
||||
logger.info("Value head model saved at: {}".format(output_dir))
|
||||
|
||||
|
||||
class FixValueHeadModelCallback(TrainerCallback):
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
@@ -51,8 +107,70 @@ class FixValueHeadModelCallback(TrainerCallback):
|
||||
)
|
||||
|
||||
|
||||
class SaveProcessorCallback(TrainerCallback):
|
||||
def __init__(self, processor: "ProcessorMixin") -> None:
|
||||
r"""
|
||||
Initializes a callback for saving the processor.
|
||||
"""
|
||||
self.processor = processor
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
|
||||
|
||||
|
||||
class PissaConvertCallback(TrainerCallback):
|
||||
r"""
|
||||
Initializes a callback for converting the PiSSA adapter to a normal one.
|
||||
"""
|
||||
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
model = kwargs.pop("model")
|
||||
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
|
||||
logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
|
||||
if isinstance(model, PeftModel):
|
||||
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
|
||||
setattr(model.peft_config["default"], "init_lora_weights", True)
|
||||
model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
|
||||
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
model = kwargs.pop("model")
|
||||
pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
|
||||
pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
|
||||
pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
|
||||
logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
|
||||
# 1. save a pissa backup with init_lora_weights: True
|
||||
# 2. save a converted lora with init_lora_weights: pissa
|
||||
# 3. load the pissa backup with init_lora_weights: True
|
||||
# 4. delete the initial adapter and change init_lora_weights to pissa
|
||||
if isinstance(model, PeftModel):
|
||||
init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
|
||||
setattr(model.peft_config["default"], "init_lora_weights", True)
|
||||
model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
|
||||
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
|
||||
model.save_pretrained(
|
||||
pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
|
||||
)
|
||||
model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
|
||||
model.set_adapter("default")
|
||||
model.delete_adapter("pissa_init")
|
||||
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
def __init__(self, output_dir: str) -> None:
|
||||
def __init__(self) -> None:
|
||||
r"""
|
||||
Initializes a callback for logging training and evaluation status.
|
||||
"""
|
||||
@@ -70,7 +188,7 @@ class LogCallback(TrainerCallback):
|
||||
self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
|
||||
if self.webui_mode:
|
||||
signal.signal(signal.SIGABRT, self._set_abort)
|
||||
self.logger_handler = LoggerHandler(output_dir)
|
||||
self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
|
||||
logging.root.addHandler(self.logger_handler)
|
||||
transformers.logging.add_handler(self.logger_handler)
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
@@ -29,7 +28,8 @@ from trl import DPOTrainer
|
||||
from trl.trainer import disable_dropout_in_model
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler, get_batch_logps
|
||||
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -54,7 +54,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
disable_dropout_in_model(ref_model)
|
||||
|
||||
self.finetuning_args = finetuning_args
|
||||
self.processor = processor
|
||||
self.f_divergence_type = "reverse_kl"
|
||||
self.reference_free = False
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.generate_during_eval = False # disable at evaluation
|
||||
@@ -92,13 +92,17 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
self.ref_model.eval()
|
||||
|
||||
if processor is not None:
|
||||
self.add_callback(SaveProcessorCallback(processor))
|
||||
|
||||
if finetuning_args.pissa_convert:
|
||||
self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
|
||||
self.callback_handler.add_callback(PissaConvertCallback)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
@@ -111,15 +115,6 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
||||
super()._save(output_dir, state_dict)
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
if self.finetuning_args.pissa_convert:
|
||||
convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args)
|
||||
|
||||
if self.processor is not None:
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
||||
def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
|
||||
r"""
|
||||
Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
|
||||
from ...data import PairwiseDataCollatorWithPadding, get_dataset
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...hparams import ModelArguments
|
||||
@@ -41,7 +41,7 @@ def run_dpo(
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||
dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
data_collator = PairwiseDataCollatorWithPadding(
|
||||
@@ -70,8 +70,8 @@ def run_dpo(
|
||||
finetuning_args=finetuning_args,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
# Training
|
||||
|
||||
@@ -27,6 +27,7 @@ from trl import KTOTrainer
|
||||
from trl.trainer import disable_dropout_in_model
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_batch_logps
|
||||
|
||||
|
||||
@@ -53,7 +54,6 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
disable_dropout_in_model(ref_model)
|
||||
|
||||
self.finetuning_args = finetuning_args
|
||||
self.processor = processor
|
||||
self.reference_free = False
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.generate_during_eval = False # disable at evaluation
|
||||
@@ -90,10 +90,14 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
self.ref_model.eval()
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
if processor is not None:
|
||||
self.add_callback(SaveProcessorCallback(processor))
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
@@ -112,12 +116,6 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
"""
|
||||
return Trainer._get_train_sampler(self)
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
||||
super()._save(output_dir, state_dict)
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
if self.processor is not None:
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
||||
def forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor"]:
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from ...data import KTODataCollatorWithPadding, get_dataset, split_dataset
|
||||
from ...data import KTODataCollatorWithPadding, get_dataset
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...hparams import ModelArguments
|
||||
@@ -41,7 +41,7 @@ def run_kto(
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
|
||||
dataset_module = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
data_collator = KTODataCollatorWithPadding(
|
||||
@@ -67,8 +67,8 @@ def run_kto(
|
||||
finetuning_args=finetuning_args,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
# Training
|
||||
|
||||
@@ -27,6 +27,8 @@ from accelerate.utils import DistributedDataParallelKwargs
|
||||
from tqdm import tqdm
|
||||
from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState
|
||||
from transformers.optimization import get_scheduler
|
||||
from transformers.trainer import DEFAULT_CALLBACKS
|
||||
from transformers.trainer_callback import CallbackHandler
|
||||
from transformers.trainer_pt_utils import remove_dummy_checkpoint
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
@@ -34,9 +36,9 @@ from trl import PPOConfig, PPOTrainer
|
||||
from trl.core import PPODecorators, logprobs_from_logits
|
||||
from trl.models.utils import unwrap_model_for_generation
|
||||
|
||||
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
|
||||
from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
|
||||
from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
|
||||
|
||||
@@ -69,15 +71,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: List["TrainerCallback"],
|
||||
callbacks: Optional[List["TrainerCallback"]],
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
reward_model: Optional["AutoModelForCausalLMWithValueHead"],
|
||||
ref_model: Optional["AutoModelForCausalLMWithValueHead"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
dataset: "Dataset",
|
||||
data_collator: "DataCollatorWithPadding",
|
||||
):
|
||||
train_dataset: Optional["Dataset"] = None,
|
||||
eval_dataset: Optional["Dataset"] = None,
|
||||
) -> None:
|
||||
if eval_dataset is not None:
|
||||
raise NotImplementedError("PPOTrainer does not support eval dataset yet.")
|
||||
|
||||
backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
|
||||
ppo_config = PPOConfig(
|
||||
model_name=model_args.model_name_or_path,
|
||||
@@ -99,18 +105,23 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
)
|
||||
|
||||
# Add deepspeed config
|
||||
ppo_config.accelerator_kwargs["kwargs_handlers"] = [
|
||||
DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
|
||||
]
|
||||
if training_args.deepspeed_plugin is not None:
|
||||
ppo_config.accelerator_kwargs["kwargs_handlers"] = [
|
||||
DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
|
||||
]
|
||||
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
|
||||
if ppo_config.log_with is not None:
|
||||
logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
|
||||
ppo_config.log_with = None
|
||||
|
||||
# Create optimizer and scheduler
|
||||
if training_args.max_steps > 0:
|
||||
num_training_steps = training_args.max_steps
|
||||
else:
|
||||
total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size
|
||||
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
|
||||
num_training_steps = training_args.num_train_epochs * math.ceil(
|
||||
len(train_dataset) / total_train_batch_size
|
||||
)
|
||||
|
||||
optimizer = self.create_optimizer(model, training_args, finetuning_args)
|
||||
scheduler = self.create_scheduler(training_args, num_training_steps, optimizer)
|
||||
@@ -121,7 +132,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
tokenizer=tokenizer,
|
||||
dataset=dataset,
|
||||
dataset=train_dataset,
|
||||
data_collator=data_collator,
|
||||
lr_scheduler=scheduler,
|
||||
)
|
||||
@@ -131,7 +142,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self.finetuning_args = finetuning_args
|
||||
self.reward_model = reward_model
|
||||
self.current_device = get_current_device() # patch for deepspeed training
|
||||
self.processor = processor
|
||||
|
||||
self.generation_config = GenerationConfig(
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
@@ -143,16 +153,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self.control = TrainerControl()
|
||||
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
||||
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
||||
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
|
||||
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
|
||||
|
||||
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
|
||||
self.callback_handler = CallbackHandler(
|
||||
callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
if self.args.max_steps > 0:
|
||||
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
|
||||
|
||||
self.amp_context = torch.autocast(self.current_device.type, dtype=self.model_args.compute_dtype)
|
||||
self.amp_context = torch.autocast(self.current_device.type)
|
||||
warnings.simplefilter("ignore") # remove gc warnings on ref model
|
||||
|
||||
if finetuning_args.reward_model_type == "full":
|
||||
@@ -165,10 +173,16 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
else:
|
||||
self.reward_model = self.accelerator.prepare_model(self.reward_model, evaluation_mode=True)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
self.add_callback(FixValueHeadModelCallback)
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
if processor is not None:
|
||||
self.add_callback(SaveProcessorCallback(processor))
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
|
||||
r"""
|
||||
@@ -202,23 +216,23 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
if self.is_world_process_zero():
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = {}".format(num_examples))
|
||||
logger.info(" Num Epochs = {}".format(num_train_epochs))
|
||||
logger.info(" Instantaneous batch size per device = {}".format(self.args.per_device_train_batch_size))
|
||||
logger.info(" Num examples = {:,}".format(num_examples))
|
||||
logger.info(" Num Epochs = {:,}".format(num_train_epochs))
|
||||
logger.info(" Instantaneous batch size per device = {:,}".format(self.args.per_device_train_batch_size))
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {}".format(
|
||||
" Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
|
||||
total_train_batch_size
|
||||
)
|
||||
)
|
||||
logger.info(" Gradient Accumulation steps = {}".format(self.args.gradient_accumulation_steps))
|
||||
logger.info(" Num optimization epochs per batch = {}".format(self.finetuning_args.ppo_epochs))
|
||||
logger.info(" Total training steps = {}".format(max_steps))
|
||||
logger.info(" Number of trainable parameters = {}".format(count_parameters(self.model)[0]))
|
||||
logger.info(" Gradient Accumulation steps = {:,}".format(self.args.gradient_accumulation_steps))
|
||||
logger.info(" Num optimization epochs per batch = {:,}".format(self.finetuning_args.ppo_epochs))
|
||||
logger.info(" Total training steps = {:,}".format(max_steps))
|
||||
logger.info(" Number of trainable parameters = {:,}".format(count_parameters(self.model)[0]))
|
||||
|
||||
dataiter = iter(self.dataloader)
|
||||
loss_meter = AverageMeter()
|
||||
reward_meter = AverageMeter()
|
||||
self.log_callback.on_train_begin(self.args, self.state, self.control)
|
||||
self.callback_handler.on_train_begin(self.args, self.state, self.control)
|
||||
|
||||
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
|
||||
try:
|
||||
@@ -256,7 +270,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
logger.warning("Failed to save stats due to unknown errors.")
|
||||
|
||||
self.state.global_step += 1
|
||||
self.log_callback.on_step_end(self.args, self.state, self.control)
|
||||
self.callback_handler.on_step_end(self.args, self.state, self.control)
|
||||
|
||||
if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
|
||||
logs = dict(
|
||||
@@ -268,7 +282,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
tqdm.write(str(logs))
|
||||
logs["step"] = step
|
||||
self.state.log_history.append(logs)
|
||||
self.log_callback.on_log(self.args, self.state, self.control)
|
||||
self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
||||
loss_meter.reset()
|
||||
reward_meter.reset()
|
||||
|
||||
@@ -276,17 +290,12 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self.save_model(
|
||||
os.path.join(self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step))
|
||||
)
|
||||
self.save_callback.on_save(
|
||||
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
||||
)
|
||||
self.callback_handler.on_save(self.args, self.state, self.control)
|
||||
|
||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||
break
|
||||
|
||||
self.log_callback.on_train_end(self.args, self.state, self.control)
|
||||
self.save_callback.on_train_end(
|
||||
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
||||
)
|
||||
self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||
|
||||
def create_optimizer(
|
||||
self,
|
||||
@@ -337,11 +346,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
batch[k] = v[:, start_index:]
|
||||
|
||||
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
||||
unwrapped_model = self.accelerator.unwrap_model(self.model) # issue in trl v0.8.6
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
if self.model_args.upcast_layernorm:
|
||||
layernorm_params = dump_layernorm(unwrapped_model)
|
||||
|
||||
generate_output: torch.Tensor = unwrapped_model.generate(
|
||||
generate_output: "torch.Tensor" = unwrapped_model.generate(
|
||||
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
|
||||
)
|
||||
if self.model_args.upcast_layernorm:
|
||||
@@ -352,12 +361,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
queries, responses = [], []
|
||||
for i in range(len(query)):
|
||||
query_start_index = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
|
||||
response_indexes = (response[i] != self.tokenizer.pad_token_id).nonzero()
|
||||
|
||||
if len(response_index) == 0:
|
||||
response_length = 1 # allow empty response
|
||||
if len(response_indexes) == 0: # allow empty response
|
||||
response_length = 1
|
||||
elif self.tokenizer.eos_token_id == self.tokenizer.pad_token_id: # include eos token
|
||||
response_length = response_indexes[-1].item() + 2
|
||||
else:
|
||||
response_length = response_index[-1].item() + 1
|
||||
response_length = response_indexes[-1].item() + 1
|
||||
|
||||
queries.append(query[i, query_start_index:]) # remove padding from left
|
||||
responses.append(response[i, :response_length]) # remove padding from right
|
||||
@@ -380,7 +391,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
|
||||
return get_rewards_from_server(self.reward_model, messages)
|
||||
|
||||
batch = self.prepare_model_inputs(queries, responses)
|
||||
batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
@@ -390,21 +401,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
reward_model = self.reward_model
|
||||
|
||||
with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context: # support bf16
|
||||
_, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True, use_cache=False)
|
||||
_, _, values = reward_model(**batch, return_dict=True, use_cache=False)
|
||||
|
||||
if self.finetuning_args.reward_model_type == "lora":
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
||||
if self.is_chatglm_model: # assume same architecture
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
rewards = []
|
||||
for i in range(values.size(0)):
|
||||
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
|
||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
|
||||
return rewards
|
||||
rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
|
||||
return rewards.float().detach() # use fp32 type
|
||||
|
||||
@PPODecorators.empty_device_cache()
|
||||
def batched_forward_pass(
|
||||
@@ -438,10 +441,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
attention_mask = input_kwargs["attention_mask"]
|
||||
|
||||
with self.amp_context: # support bf16
|
||||
logits, _, values = model(**input_kwargs)
|
||||
|
||||
if self.is_chatglm_model:
|
||||
values = torch.transpose(values, 0, 1)
|
||||
logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)
|
||||
|
||||
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
|
||||
masks = torch.zeros_like(attention_mask)
|
||||
@@ -503,8 +503,5 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self.model.save_checkpoint(output_dir)
|
||||
|
||||
elif self.args.should_save:
|
||||
self._save(output_dir)
|
||||
|
||||
if self.processor is not None and self.args.should_save:
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
self._save(output_dir, state_dict=unwrapped_model.state_dict())
|
||||
|
||||
@@ -20,10 +20,9 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
from transformers import DataCollatorWithPadding
|
||||
|
||||
from ...data import get_dataset
|
||||
from ...extras.callbacks import FixValueHeadModelCallback
|
||||
from ...extras.misc import fix_valuehead_checkpoint
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..callbacks import fix_valuehead_checkpoint
|
||||
from ..trainer_utils import create_ref_model, create_reward_model
|
||||
from .trainer import CustomPPOTrainer
|
||||
|
||||
@@ -44,7 +43,7 @@ def run_ppo(
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
|
||||
dataset_module = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
|
||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
@@ -55,17 +54,17 @@ def run_ppo(
|
||||
reward_model = create_reward_model(model, model_args, finetuning_args)
|
||||
|
||||
# Initialize our Trainer
|
||||
ppo_trainer = CustomPPOTrainer(
|
||||
ppo_trainer: "CustomPPOTrainer" = CustomPPOTrainer(
|
||||
model_args=model_args,
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
generating_args=generating_args,
|
||||
callbacks=callbacks + [FixValueHeadModelCallback()],
|
||||
callbacks=callbacks,
|
||||
model=model,
|
||||
reward_model=reward_model,
|
||||
ref_model=ref_model,
|
||||
dataset=dataset,
|
||||
data_collator=data_collator,
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
)
|
||||
|
||||
@@ -75,6 +74,7 @@ def run_ppo(
|
||||
ppo_trainer.save_model()
|
||||
if training_args.should_save:
|
||||
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
|
||||
|
||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
||||
|
||||
@@ -12,14 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ..trainer_utils import convert_pissa_adapter, create_custom_optimzer, create_custom_scheduler
|
||||
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -42,15 +42,18 @@ class CustomTrainer(Trainer):
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
self.processor = processor
|
||||
|
||||
if processor is not None:
|
||||
self.add_callback(SaveProcessorCallback(processor))
|
||||
|
||||
if finetuning_args.pissa_convert:
|
||||
self.save_model(os.path.join(self.args.output_dir, "pissa_init"))
|
||||
self.add_callback(PissaConvertCallback)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
@@ -62,12 +65,3 @@ class CustomTrainer(Trainer):
|
||||
) -> "torch.optim.lr_scheduler.LRScheduler":
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
||||
super()._save(output_dir, state_dict)
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
if self.finetuning_args.pissa_convert:
|
||||
convert_pissa_adapter(output_dir, state_dict, self.accelerator, self.model, self.args)
|
||||
|
||||
if self.processor is not None:
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
||||
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
|
||||
from ...data import get_dataset, split_dataset
|
||||
from ...data import get_dataset
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..trainer_utils import create_modelcard_and_push
|
||||
@@ -42,7 +42,7 @@ def run_pt(
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
|
||||
dataset_module = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
@@ -53,8 +53,8 @@ def run_pt(
|
||||
finetuning_args=finetuning_args,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
# Training
|
||||
|
||||
@@ -12,11 +12,38 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Sequence, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...extras.misc import numpify
|
||||
|
||||
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
||||
preds, _ = eval_preds
|
||||
return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import EvalPrediction
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComputeAccuracy:
|
||||
def _dump(self) -> Optional[Dict[str, float]]:
|
||||
result = None
|
||||
if hasattr(self, "score_dict"):
|
||||
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
|
||||
|
||||
self.score_dict = {"accuracy": []}
|
||||
return result
|
||||
|
||||
def __post_init__(self):
|
||||
self._dump()
|
||||
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
|
||||
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
|
||||
if not chosen_scores.shape:
|
||||
self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
|
||||
else:
|
||||
for i in range(len(chosen_scores)):
|
||||
self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i])
|
||||
|
||||
if compute_result:
|
||||
return self._dump()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the CarperAI's trlx library.
|
||||
# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,28 +14,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2022 CarperAI
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import json
|
||||
import os
|
||||
@@ -46,6 +24,7 @@ import torch
|
||||
from transformers import Trainer
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
|
||||
|
||||
|
||||
@@ -69,12 +48,20 @@ class PairwiseTrainer(Trainer):
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
self.processor = processor
|
||||
self.can_return_loss = True # override property to return eval_loss
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
self.add_callback(FixValueHeadModelCallback)
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
if processor is not None:
|
||||
self.add_callback(SaveProcessorCallback(processor))
|
||||
|
||||
if finetuning_args.pissa_convert:
|
||||
self.add_callback(PissaConvertCallback)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
@@ -87,12 +74,6 @@ class PairwiseTrainer(Trainer):
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
||||
super()._save(output_dir, state_dict)
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
if self.processor is not None:
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
||||
def compute_loss(
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
@@ -102,49 +83,21 @@ class PairwiseTrainer(Trainer):
|
||||
Subclass and override to inject custom behavior.
|
||||
|
||||
Note that the first element will be removed from the output tuple.
|
||||
See: https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/trainer.py#L3777
|
||||
See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
|
||||
"""
|
||||
# Compute rewards
|
||||
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
||||
|
||||
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
# Split the inputs and rewards into two parts, chosen and rejected
|
||||
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False)
|
||||
batch_size = inputs["input_ids"].size(0) // 2
|
||||
chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
|
||||
chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
|
||||
chosen_scores, rejected_scores = [], []
|
||||
chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0)
|
||||
chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
|
||||
chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
|
||||
rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
|
||||
chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
|
||||
|
||||
# Compute pairwise loss. Only backprop on the different tokens before padding
|
||||
loss = 0
|
||||
for i in range(batch_size):
|
||||
chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
||||
rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
||||
check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
|
||||
|
||||
if len(check_divergence) == 0:
|
||||
end_index = chosen_length
|
||||
div_index = end_index - 1
|
||||
else:
|
||||
end_index = max(chosen_length, rejected_length)
|
||||
div_index = check_divergence[0]
|
||||
|
||||
assert div_index > 0
|
||||
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
|
||||
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
|
||||
if return_outputs: # use the score on the last token except pad token for inference
|
||||
chosen_scores.append(chosen_rewards[i, chosen_length - 1])
|
||||
rejected_scores.append(rejected_rewards[i, rejected_length - 1])
|
||||
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
|
||||
|
||||
loss = loss / batch_size
|
||||
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
|
||||
if return_outputs:
|
||||
chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
|
||||
return loss, [loss, chosen_scores, rejected_scores]
|
||||
|
||||
return loss
|
||||
return loss, (loss, chosen_scores, rejected_scores)
|
||||
else:
|
||||
return loss
|
||||
|
||||
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
||||
r"""
|
||||
@@ -163,4 +116,5 @@ class PairwiseTrainer(Trainer):
|
||||
res: List[str] = []
|
||||
for c_score, r_score in zip(chosen_scores, rejected_scores):
|
||||
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
|
||||
|
||||
writer.write("\n".join(res))
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the CarperAI's trlx library.
|
||||
# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,38 +14,15 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2022 CarperAI
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
|
||||
from ...extras.callbacks import FixValueHeadModelCallback
|
||||
from ...extras.misc import fix_valuehead_checkpoint
|
||||
from ...data import PairwiseDataCollatorWithPadding, get_dataset
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..callbacks import fix_valuehead_checkpoint
|
||||
from ..trainer_utils import create_modelcard_and_push
|
||||
from .metric import compute_accuracy
|
||||
from .metric import ComputeAccuracy
|
||||
from .trainer import PairwiseTrainer
|
||||
|
||||
|
||||
@@ -64,7 +41,7 @@ def run_rm(
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||
dataset_module = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
||||
|
||||
@@ -77,10 +54,10 @@ def run_rm(
|
||||
args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks + [FixValueHeadModelCallback()],
|
||||
compute_metrics=compute_accuracy,
|
||||
callbacks=callbacks,
|
||||
compute_metrics=ComputeAccuracy(),
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
# Training
|
||||
@@ -89,6 +66,7 @@ def run_rm(
|
||||
trainer.save_model()
|
||||
if training_args.should_save:
|
||||
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
|
||||
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
@@ -103,7 +81,7 @@ def run_rm(
|
||||
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
predict_results = trainer.predict(dataset, metric_key_prefix="predict")
|
||||
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict")
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(predict_results)
|
||||
|
||||
@@ -17,17 +17,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers.utils import is_jieba_available, is_nltk_available
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.misc import numpify
|
||||
from ...extras.packages import is_rouge_available
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers import EvalPrediction, PreTrainedTokenizer
|
||||
|
||||
|
||||
if is_jieba_available():
|
||||
@@ -42,20 +44,64 @@ if is_rouge_available():
|
||||
from rouge_chinese import Rouge
|
||||
|
||||
|
||||
def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
|
||||
if isinstance(logits, (list, tuple)):
|
||||
if logits[0].dim() == 3: # (batch_size, seq_len, vocab_size)
|
||||
logits = logits[0]
|
||||
else: # moe models have aux loss
|
||||
logits = logits[1]
|
||||
|
||||
if logits.dim() != 3:
|
||||
raise ValueError("Cannot process the logits.")
|
||||
|
||||
return torch.argmax(logits, dim=-1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComputeMetrics:
|
||||
class ComputeAccuracy:
|
||||
def _dump(self) -> Optional[Dict[str, float]]:
|
||||
result = None
|
||||
if hasattr(self, "score_dict"):
|
||||
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
|
||||
|
||||
self.score_dict = {"accuracy": []}
|
||||
return result
|
||||
|
||||
def __post_init__(self):
|
||||
self._dump()
|
||||
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
|
||||
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
|
||||
for i in range(len(preds)):
|
||||
pred, label = preds[i, :-1], labels[i, 1:]
|
||||
label_mask = label != IGNORE_INDEX
|
||||
self.score_dict["accuracy"].append(np.mean(pred[label_mask] == label[label_mask]))
|
||||
|
||||
if compute_result:
|
||||
return self._dump()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComputeSimilarity:
|
||||
r"""
|
||||
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
|
||||
Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
|
||||
"""
|
||||
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
|
||||
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
||||
r"""
|
||||
Uses the model predictions to compute metrics.
|
||||
"""
|
||||
preds, labels = eval_preds
|
||||
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
def _dump(self) -> Optional[Dict[str, float]]:
|
||||
result = None
|
||||
if hasattr(self, "score_dict"):
|
||||
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
|
||||
|
||||
self.score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
return result
|
||||
|
||||
def __post_init__(self):
|
||||
self._dump()
|
||||
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
|
||||
preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
|
||||
|
||||
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
|
||||
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
|
||||
@@ -75,9 +121,10 @@ class ComputeMetrics:
|
||||
result = scores[0]
|
||||
|
||||
for k, v in result.items():
|
||||
score_dict[k].append(round(v["f"] * 100, 4))
|
||||
self.score_dict[k].append(round(v["f"] * 100, 4))
|
||||
|
||||
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
|
||||
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
|
||||
self.score_dict["bleu-4"].append(round(bleu_score * 100, 4))
|
||||
|
||||
return {k: float(np.mean(v)) for k, v in score_dict.items()}
|
||||
if compute_result:
|
||||
return self._dump()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user