Compare commits

259 Commits:

```
57354fc990 89f240805c 27bbea886c 3ec3dda33a ae9f338bf7 bf44f76dc7 c18581f0a4 9f6c5c4798 7bc03ac986 85d7e4f4ab
bf69747f40 f1146bf7b6 9efd1fec90 3b91839a55 bc4421eeef 5003820a6a cd2485f28d 918a367378 3d35aeca72 53b1e5fd1d
b852c895cf aaa7ed8712 205aca5b03 87b1f851f1 fca814b30d a20c2b6ecf fee94e1c54 047a596542 3d45606984 310c107d56
089e4d9e96 ae56c3cf49 0a0288a286 25da686758 e2da3cc9fa c42e5cf401 9943cd1c96 1e6f96508a d401974f69 09b2dbe859
7f8ef8c132 fcb6283a72 0027f46ccc 967a27695e 3ce8a326c6 91b56b7baf e2fa961302 87d6d7dc61 00019e2ca4 b104739d63
b238d1aa04 aa497d5d96 fecf04b2f4 3f157e2f6f c7c558562e c2ea5fb618 fa9c32bb8d c610deb5a2 2bb3255e74 b28b74c71e
1ed921bff7 80f634cc95 a3eb5e200c 2d02c0e22d 093eda2ad6 dbaf621f57 ceb701c2d4 29ad3783f5 fa2386e73c e0045e8386
b94c941196 ba66ac084f 83479c9ef0 df8ac15ef0 8cea5cd967 a2d7d6a518 a63e624eca 8596c321ce 54cd799aa0 8185eb1890
03213984ec aeeee9d4b5 c8a1fb99bf f0181a41ff f6b06d0c6f 1047217f78 16a9a44849 58fb24ce41 a9afffa246 1fdd053022
0a833968a0 58b681de78 22d5fc5f4c cc0119f698 580cedebde 43bd1b070c 42aa9c65be b0b87fa33f 22912eba1a e2748fa967
248d5daaff 8f5921692e e880eb8844 dc076c4e52 8306e93ef3 6a2cd129c0 30d7f6a22e 5440ebbae6 22dbe694e9 64ac6ca396
377d37fa7f 55296744a8 d0889012c2 3a8b2890eb 5b2284a51d 4807d8a4ef c6e1313977 66819fd3ee bd85e370be cc097174cc
7d135bbdb8 4845a76535 67645c0db8 f463b3f038 01defc2779 c9e77ab352 c3de160d1c 3693d7b571 a63144c28f 2b3b0473cd
9d929897ce 313a5e1494 74dd25224a c7efc7f2ed c71c78da50 f4897da009 a6951db970 9d27aaa38f 3b19b6f31b 5b15ca0b0b
aad79127e6 c42dcab32b be519c84d9 b2dc6dc59a 9df626dc18 8d4b9200a1 7806df46ba bba026a212 6e111eb29f 2b69ae0eb2
13d73574ef bc264807ae f9815dd20a 1f58943b32 6476507429 35862d19ec 1272cb00df e9ac26db4c 20ee1d2e19 cbc1dd0c88
870bbabbc4 8fd84c375e 32b5364051 cf72aec098 87849d12d2 a19512436f 6c89d93aea 345f40a660 8b9a814653 05fabf9095
95eede911a 7bc7f7d673 054fdbe186 f0f80819a0 e702678252 553579986a 622cb04f27 f3ba11a432 8b1f53bca5 ac25fef80e
15f819d273 f2d1c43d28 464acc7d6c a96c5da737 28d09b81c9 a769d0e3d4 1b98b5e65c 3cc5408da7 689f5c4554 ab5d042cd3
4d43317aa1 ed3b0c5b40 67a97794ee 2c7c93cb9b 4d4fe08d14 85a919b6f7 fe2abe20fc 12444720db 510faf5805 722e01c8ab
6050e6cff9 c8abbe4fc3 f2881c9d4a 1ded3abdf1 e641f1215a ca736bcab7 bddb2646bd e4c57f54f8 6de82ca843 b2c02df555
ca86d6361e b6fb00e046 86c84972c8 9390927875 c4a585f232 300feb3245 cacafb0038 6509114259 7d4cb79822 b867e164fe
26bbfc084d c376eed31d 7c595abc38 c428ab68d8 968b9f1852 018266c66e 111c644bf1 de72d1f0e7 8bfb856923 8fdbaab95d
a01668bbe8 3385616a37 1f0d89328d a7feab45d5 f34322afd7 3815fa40b7 c43050b3fa 3e152872ad ae6ad55758 0118a2fc04
4dd81976f4 2b4da8baf6 7d1b4071e8 8fc5377f50 e5812f261d f7e85cd7de 749395420b 7d536d1d75 7fd0d2fc2f ec696bbcdd
df24345d65 386dd26097 514f976cc1 66b870fd08 24d3c7e378 484128b641 588ea95732 800567cde7 7a3ba5a25d
```
**.dockerignore**

```
@@ -7,6 +7,8 @@ data
docker
saves
hf_cache
ms_cache
om_cache
output
.dockerignore
.gitattributes
```

**.env.local** (18 changes)
```
@@ -1,33 +1,35 @@
# Note: actually we do not support .env, just for reference
# api
API_HOST=0.0.0.0
API_PORT=8000
API_HOST=
API_PORT=
API_KEY=
API_MODEL_NAME=gpt-3.5-turbo
API_MODEL_NAME=
FASTAPI_ROOT_PATH=
MAX_CONCURRENT=
# general
DISABLE_VERSION_CHECK=
FORCE_CHECK_IMPORTS=
FORCE_TORCHRUN=
LLAMAFACTORY_VERBOSITY=
USE_MODELSCOPE_HUB=
USE_OPENMIND_HUB=
RECORD_VRAM=
# torchrun
FORCE_TORCHRUN=
MASTER_ADDR=
MASTER_PORT=
NNODES=
RANK=
NODE_RANK=
NPROC_PER_NODE=
# wandb
WANDB_DISABLED=
WANDB_PROJECT=huggingface
WANDB_PROJECT=
WANDB_API_KEY=
# gradio ui
GRADIO_SHARE=False
GRADIO_SERVER_NAME=0.0.0.0
GRADIO_SHARE=
GRADIO_SERVER_NAME=
GRADIO_SERVER_PORT=
GRADIO_ROOT_PATH=
GRADIO_IPV6=
# setup
ENABLE_SHORT_CONSOLE=1
# reserved (do not use)
```
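Because the file's header says `.env` is not actually loaded and is kept only for reference, these variables take effect through the process environment; a minimal sketch, with illustrative values:

```bash
# assumption: variables are read from the shell environment, not from .env
export GRADIO_SERVER_NAME=0.0.0.0
export GRADIO_SERVER_PORT=7860
llamafactory-cli webui
```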
**.github/CONTRIBUTING.md** (vendored, 46 changes)

@@ -19,3 +19,49 @@ There are several ways you can contribute to LLaMA Factory:

### Style guide

LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.

### Create a Pull Request

1. Fork the [repository](https://github.com/hiyouga/LLaMA-Factory) by clicking on the [Fork](https://github.com/hiyouga/LLaMA-Factory/fork) button on the repository's page. This creates a copy of the code under your GitHub user account.

2. Clone your fork to your local disk, and add the base repository as a remote:

   ```bash
   git clone git@github.com:[username]/LLaMA-Factory.git
   cd LLaMA-Factory
   git remote add upstream https://github.com/hiyouga/LLaMA-Factory.git
   ```

3. Create a new branch to hold your development changes:

   ```bash
   git checkout -b dev_your_branch
   ```

4. Set up a development environment by running the following command in a virtual environment:

   ```bash
   pip install -e ".[dev]"
   ```

   If LLaMA Factory was already installed in the virtual environment, remove it with `pip uninstall llamafactory` before reinstalling it in editable mode with the `-e` flag.

5. Check code before commit:

   ```bash
   make commit
   make style && make quality
   make test
   ```

6. Submit changes:

   ```bash
   git add .
   git commit -m "commit message"
   git fetch upstream
   git rebase upstream/main
   git push -u origin dev_your_branch
   ```

7. Create a merge request from your branch `dev_your_branch` at [origin repo](https://github.com/hiyouga/LLaMA-Factory).
**.github/workflows/publish.yml** (vendored, 6 changes)

```
@@ -26,15 +26,15 @@ jobs:
        uses: actions/setup-python@v5
        with:
          python-version: "3.8"

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install build

      - name: Build package
        run: |
          python -m build

      - name: Publish package
        uses: pypa/gh-action-pypi-publish@release/v1
```
**.github/workflows/tests.yml** (vendored, 3 changes)

```
@@ -22,7 +22,7 @@ jobs:
      fail-fast: false
      matrix:
        python-version:
          - "3.8"
          - "3.8" # TODO: remove py38 in next transformers release
          - "3.9"
          - "3.10"
          - "3.11"
```

```
@@ -54,7 +54,6 @@ jobs:
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install git+https://github.com/huggingface/transformers.git
          python -m pip install ".[torch,dev]"

      - name: Check quality
```
**.gitignore** (vendored, 4 changes)

```
@@ -159,9 +159,13 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# vscode
.vscode/

# custom .gitignore
ms_cache/
hf_cache/
om_cache/
cache/
config/
saves/
```
**.pre-commit-config.yaml** (new file, 28 lines)

```yaml
@@ -0,0 +1,28 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: check-ast
      - id: check-added-large-files
        args: ['--maxkb=25000']
      - id: check-merge-conflict
      - id: check-yaml
      - id: debug-statements
      - id: end-of-file-fixer
      - id: trailing-whitespace
        args: [--markdown-linebreak-ext=md]
      - id: no-commit-to-branch
        args: ['--branch', 'main']

  - repo: https://github.com/asottile/pyupgrade
    rev: v3.17.0
    hooks:
      - id: pyupgrade
        args: [--py38-plus]

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.6.9
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
```
**Makefile** (11 changes)

```
@@ -1,7 +1,14 @@
.PHONY: quality style test
.PHONY: build commit quality style test

check_dirs := scripts src tests setup.py

build:
	pip install build && python -m build

commit:
	pre-commit install
	pre-commit run --all-files

quality:
	ruff check $(check_dirs)
	ruff format --check $(check_dirs)
```

```
@@ -11,4 +18,4 @@ style:
	ruff format $(check_dirs)

test:
	CUDA_VISIBLE_DEVICES= pytest tests/
	CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
```
**README.md** (112 changes)

@@ -4,7 +4,7 @@

[](LICENSE)
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[](https://pypi.org/project/llamafactory/)
[](#projects-using-llama-factory)
[](#projects-using-llama-factory)
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
[](https://discord.gg/rKfvV9r9FK)
[](https://twitter.com/llamafactory_ai)

@@ -12,6 +12,7 @@

[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)

[](https://trendshift.io/repositories/4535)
@@ -25,10 +26,18 @@ https://github.com/user-attachments/assets/7c96b465-9df7-45f4-8053-bf03e58386d3

Choose your path:

- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **Local machine**: Please refer to [usage](#getting-started)
- **Documentation (WIP)**: https://llamafactory.readthedocs.io/zh-cn/latest/
- **Colab**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
- **Local machine**: Please refer to [usage](#getting-started)
- **PAI-DSW**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
- **Amazon SageMaker**: [Blog](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)

Recent activities:

- **2024/10/18-2024/11/30**: Build a personal tour guide bot using PAI+LLaMA Factory. [[website]](https://developer.aliyun.com/topic/llamafactory2)

> [!NOTE]
> Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.

## Table of Contents
@@ -72,6 +81,10 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

## Changelog

[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#download-from-modelers-hub) for usage.

[24/09/19] We support fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.

[24/08/30] We support fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thank [@simonJJJ](https://github.com/simonJJJ)'s PR.

[24/08/27] We support **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Try `enable_liger_kernel: true` for efficient training.

@@ -128,7 +141,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).

[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#download-from-modelscope-hub) for usage.
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#download-from-modelscope-hub) for usage.

[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `neftune_noise_alpha: 5` argument to activate NEFTune.
@@ -160,34 +173,40 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

## Supported Models

| Model | Model size | Template |
| --- | --- | --- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

| Model | Model size | Template |
| --- | --- | --- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
| [Phi-3-small](https://huggingface.co/microsoft) | 7B | phi_small |
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

> [!NOTE]
> For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
@@ -356,7 +375,7 @@ cd LLaMA-Factory

```
pip install -e ".[torch,metrics]"
```

Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality

> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.

@@ -408,7 +427,7 @@ Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaij

### Data Preparation

Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope hub or load the dataset in local disk.
Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope / Modelers hub or load the dataset in local disk.

> [!NOTE]
> Please update `data/dataset_info.json` to use your custom dataset.
@@ -476,6 +495,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \

```
docker run -dit --gpus=all \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -p 7860:7860 \
```

@@ -500,6 +520,7 @@ docker build -f ./docker/docker-npu/Dockerfile \

```
docker run -dit \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -v /usr/local/dcmi:/usr/local/dcmi \
```

@@ -533,6 +554,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \

```
docker run -dit \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -v ./saves:/app/saves \
```
@@ -553,6 +575,7 @@ docker exec -it llamafactory bash

- `hf_cache`: Utilize Hugging Face cache on the host machine. Reassignable if a cache already exists in a different directory.
- `ms_cache`: Similar to Hugging Face cache but for ModelScope users.
- `om_cache`: Similar to Hugging Face cache but for Modelers users.
- `data`: Place datasets on this dir of the host machine so that they can be selected on LLaMA Board GUI.
- `output`: Set export dir to this location so that the merged result can be accessed directly on the host machine.

@@ -566,6 +589,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml

> [!TIP]
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
>
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
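Since the server exposes an OpenAI-style chat completions endpoint (per the API reference linked in the tip above), a minimal request sketch against the server started with `API_PORT=8000`; the model name and payload are illustrative:

```bash
# assumption: the server runs on localhost and the model name is not validated
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "test", "messages": [{"role": "user", "content": "Hello!"}]}'
```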
### Download from ModelScope Hub

@@ -577,6 +602,16 @@ export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows

Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`. You can find a full list of model IDs at [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.
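Putting the two steps together, a minimal sketch of a training run that pulls the model from ModelScope; the yaml file is one of the stock examples, and the `model_name_or_path` fragment in the comment is an assumption about where the ID goes:

```bash
export USE_MODELSCOPE_HUB=1  # `set USE_MODELSCOPE_HUB=1` on Windows
# in the training yaml (fragment, not a complete config):
#   model_name_or_path: LLM-Research/Meta-Llama-3-8B-Instruct
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```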
### Download from Modelers Hub

You can also use Modelers Hub to download models and datasets.

```bash
export USE_OPENMIND_HUB=1 # `set USE_OPENMIND_HUB=1` for Windows
```

Train the model by specifying a model ID of the Modelers Hub as the `model_name_or_path`. You can find a full list of model IDs at [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
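The same pattern as the ModelScope sketch above applies here, swapping the hub switch and the model ID (illustrative):

```bash
export USE_OPENMIND_HUB=1  # `set USE_OPENMIND_HUB=1` on Windows
# yaml fragment: model_name_or_path: TeleAI/TeleChat-7B-pt
llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```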
### Use W&B Logger

To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.
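The argument list itself was elided in this export; a minimal sketch of the typical keys — `run_name: test_run` appears in the corresponding README_zh.md hunk header below, while `report_to: wandb` is assumed from the standard Transformers trainer convention:

```yaml
report_to: wandb
run_name: test_run # optional
```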
@@ -675,16 +710,19 @@ If you have a project that should be incorporated, please contact via email or c

1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
1. Bai et al. Aligning Large Language Model with Direct Multi-Preference Optimization for Recommendation. CIKM 2024. [[paper]](https://dl.acm.org/doi/10.1145/3627673.3679611)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model specialized in generate metadata for stable diffusion. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A multimodal large language model specialized in Chinese medical domain, based on LLaVA-1.5-7B.
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PC for NVIDIA RTX.
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)

</details>

@@ -692,7 +730,7 @@ If you have a project that should be incorporated, please contact via email or c

This repository is licensed under the [Apache-2.0 License](LICENSE).

Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
Please follow the model licenses to use the corresponding model weights: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation
**README_zh.md** (108 changes)

@@ -4,7 +4,7 @@

[](LICENSE)
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
[](https://pypi.org/project/llamafactory/)
[](#使用了-llama-factory-的项目)
[](#使用了-llama-factory-的项目)
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
[](https://discord.gg/rKfvV9r9FK)
[](https://twitter.com/llamafactory_ai)

@@ -12,6 +12,7 @@

[](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)
[](https://huggingface.co/spaces/hiyouga/LLaMA-Board)
[](https://modelscope.cn/studios/hiyouga/LLaMA-Board)
[](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)

[](https://trendshift.io/repositories/4535)
@@ -25,11 +26,19 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

Choose your path:

- **Colab**: https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **PAI-DSW**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
- **Local machine**: see [How to use](#如何使用)
- **Getting started tutorial**: https://zhuanlan.zhihu.com/p/695287607
- **Documentation**: https://llamafactory.readthedocs.io/zh-cn/latest/
- **Colab**: https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
- **Local machine**: see [How to use](#如何使用)
- **PAI-DSW**: [Llama3 example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)
- **Amazon SageMaker**: [Blog](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)

Recent activities:

- **2024/10/18-2024/11/30**: Build a personal tour-guide bot with PAI + LLaMA Factory. [[event page]](https://developer.aliyun.com/topic/llamafactory2)

> [!NOTE]
> All websites other than the links above are unauthorized third-party sites. Please use them with care.

## Table of Contents
@@ -73,6 +82,10 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

## Changelog

[24/10/09] We supported downloading pre-trained models and datasets from the **[Modelers Hub](https://modelers.cn/models)**. See [this tutorial](#从魔乐社区下载) for usage.

[24/09/19] We supported fine-tuning the **[Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/)** models.

[24/08/30] We supported fine-tuning the **[Qwen2-VL](https://qwenlm.github.io/blog/qwen2-vl/)** models. Thanks to [@simonJJJ](https://github.com/simonJJJ)'s PR.

[24/08/27] We supported **[Liger Kernel](https://github.com/linkedin/Liger-Kernel)**. Use `enable_liger_kernel: true` to accelerate training.
@@ -161,34 +174,39 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

## Models

| Model | Model size | Template |
| --- | --- | --- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

| Model | Model size | Template |
| --- | --- | --- |
| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
| [Phi-3](https://huggingface.co/microsoft) | 4B/14B | phi |
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
| [Qwen (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl |
| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

> [!NOTE]
> For all "base" models, the `template` argument can be any of `default`, `alpaca`, `vicuna`, etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
@@ -357,7 +375,7 @@ cd LLaMA-Factory

```
pip install -e ".[torch,metrics]"
```

Optional extra dependencies: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, quality
Optional extra dependencies: torch, torch-npu, metrics, deepspeed, liger-kernel, bitsandbytes, hqq, eetq, gptq, awq, aqlm, vllm, galore, badam, adam-mini, qwen, modelscope, openmind, quality

> [!TIP]
> Use `pip install --no-deps -e .` to resolve package conflicts.

@@ -409,7 +427,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh

### Data Preparation

For details on the dataset file format, please refer to [data/README_zh.md](data/README_zh.md). You can use datasets on the HuggingFace / ModelScope hub or load a dataset from local disk.
For details on the dataset file format, please refer to [data/README_zh.md](data/README_zh.md). You can use datasets on the HuggingFace / ModelScope / Modelers hub or load a dataset from local disk.

> [!NOTE]
> Please update `data/dataset_info.json` when using a custom dataset.
@@ -477,6 +495,7 @@ docker build -f ./docker/docker-cuda/Dockerfile \

```
docker run -dit --gpus=all \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -p 7860:7860 \
```

@@ -501,6 +520,7 @@ docker build -f ./docker/docker-npu/Dockerfile \

```
docker run -dit \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -v /usr/local/dcmi:/usr/local/dcmi \
```

@@ -534,6 +554,7 @@ docker build -f ./docker/docker-rocm/Dockerfile \

```
docker run -dit \
    -v ./hf_cache:/root/.cache/huggingface \
    -v ./ms_cache:/root/.cache/modelscope \
    -v ./om_cache:/root/.cache/openmind \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -v ./saves:/app/saves \
```
@@ -554,6 +575,7 @@ docker exec -it llamafactory bash

- `hf_cache`: use the host machine's Hugging Face cache directory; it may be redirected if a cache already exists elsewhere.
- `ms_cache`: similar to the Hugging Face cache, for ModelScope users.
- `om_cache`: similar to the Hugging Face cache, for Modelers users.
- `data`: the host directory that holds datasets, so they can be selected in the LLaMA Board GUI.
- `output`: set the export directory to this location so that merged results can be accessed directly on the host machine.

@@ -567,6 +589,8 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml

> [!TIP]
> See the API documentation [here](https://platform.openai.com/docs/api-reference/chat/create).
>
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
### Download from ModelScope Hub

@@ -578,6 +602,16 @@ export USE_MODELSCOPE_HUB=1 # use `set USE_MODELSCOPE_HUB=1` on Windows

Set `model_name_or_path` to a model ID to load the corresponding model. You can browse all available models at the [ModelScope Hub](https://modelscope.cn/models), e.g., `LLM-Research/Meta-Llama-3-8B-Instruct`.

### Download from Modelers Hub

You can also download models and datasets from the Modelers Hub as follows.

```bash
export USE_OPENMIND_HUB=1 # use `set USE_OPENMIND_HUB=1` on Windows
```

Set `model_name_or_path` to a model ID to load the corresponding model. You can browse all available models at the [Modelers Hub](https://modelers.cn/models), e.g., `TeleAI/TeleChat-7B-pt`.
### Use W&B Logger

To log experiment results with [Weights & Biases](https://wandb.ai), add the following arguments to the yaml files.
@@ -676,16 +710,18 @@ run_name: test_run # optional

1. Zeng et al. Perceive, Reflect, and Plan: Designing LLM Agent for Goal-Directed City Navigation without Instructions. 2024. [[arxiv]](https://arxiv.org/abs/2408.04168)
1. Xia et al. Using Pre-trained Language Model for Accurate ESG Prediction. FinNLP 2024. [[paper]](https://aclanthology.org/2024.finnlp-2.1/)
1. Liang et al. I-SHEEP: Self-Alignment of LLM from Scratch through an Iterative Self-Enhancement Paradigm. 2024. [[arxiv]](https://arxiv.org/abs/2408.08072)
1. Bai et al. Aligning Large Language Model with Direct Multi-Preference Optimization for Recommendation. CIKM 2024. [[paper]](https://dl.acm.org/doi/10.1145/3627673.3679611)
1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: An astronomy LLM, fine-tuned from ChatGLM2-6B and Qwen-14B on astronomical data.
1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A Chinese legal-domain LLM fine-tuned from Baichuan-13B, capable of legal reasoning and knowledge retrieval.
1. **[Sunsimiao](https://github.com/X-D-Lab/Sunsimiao)**: A Chinese medical LLM, fine-tuned from Baichuan-7B and ChatGLM-6B on Chinese medical data.
1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of Chinese medical LLMs, fine-tuned from LLaMA2-7B and Baichuan-13B on Chinese medical data.
1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI-personality LLMs, giving any LLM one of 16 personality types depending on the dataset and training method.
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model for generating Stable Diffusion prompts. [[🤗Demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3)**: A large language model for generating Stable Diffusion prompts. [[demo]](https://huggingface.co/spaces/Nekochu/Luminia-13B_SD_Prompt)
1. **[Chinese-LLaVA-Med](https://github.com/BUAADreamer/Chinese-LLaVA-Med)**: A Chinese multimodal medical LLM, fine-tuned from LLaVA-1.5-7B on Chinese multimodal medical data.
1. **[AutoRE](https://github.com/THUDM/AutoRE)**: A document-level relation extraction system based on large language models.
1. **[NVIDIA RTX AI Toolkit](https://github.com/NVIDIA/RTX-AI-Toolkit)**: SDKs for fine-tuning LLMs on Windows PCs with NVIDIA RTX GPUs.
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: A low-code tool for building multi-agent LLM applications, with support for model fine-tuning via LLaMA Factory.
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)

</details>

@@ -693,7 +729,7 @@ run_name: test_run # optional

The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license.

When using the model weights, please follow the corresponding model licenses: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
When using the model weights, please follow the corresponding model licenses: [Baichuan 2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Command R](https://cohere.com/c4ai-cc-by-nc-license) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [GLM-4](https://huggingface.co/THUDM/glm-4-9b/blob/main/LICENSE) / [Index](https://huggingface.co/IndexTeam/Index-1.9B/blob/main/LICENSE) / [InternLM2](https://github.com/InternLM/InternLM#license) / [Llama](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [Llama 2 (LLaVA-1.5)](https://ai.meta.com/llama/license/) / [Llama 3](https://llama.meta.com/llama3/license/) / [MiniCPM](https://github.com/OpenBMB/MiniCPM/blob/main/MiniCPM%20Model%20License.md) / [Mistral/Mixtral/Pixtral](LICENSE) / [OLMo](LICENSE) / [Phi-1.5/Phi-2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/LICENSE) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [StarCoder 2](https://huggingface.co/spaces/bigcode/bigcode-model-license-agreement) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yi-1.5](LICENSE) / [Yuan 2](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

## Citation
**assets/benchmark.svg** (1630 changes)

File diff suppressed because it is too large. (Before: 29 KiB; After: 28 KiB)
```
@@ -17,9 +17,9 @@ _CITATION = """\
}
"""

_HOMEPAGE = "{}/datasets/BelleGroup/multiturn_chat_0.8M".format(_HF_ENDPOINT)
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M"
_LICENSE = "gpl-3.0"
_URL = "{}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json".format(_HF_ENDPOINT)
_URL = f"{_HF_ENDPOINT}/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"


class BelleMultiturn(datasets.GeneratorBasedBuilder):
```

```
@@ -38,7 +38,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": file_path})]

    def _generate_examples(self, filepath: str):
        with open(filepath, "r", encoding="utf-8") as f:
        with open(filepath, encoding="utf-8") as f:
            for key, row in enumerate(f):
                data = json.loads(row)
                conversations = []
```
```
@@ -8,9 +8,9 @@ import datasets
_HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
_CITATION = ""
_HOMEPAGE = "{}/datasets/Anthropic/hh-rlhf".format(_HF_ENDPOINT)
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf"
_LICENSE = "mit"
_URL = "{}/datasets/Anthropic/hh-rlhf/resolve/main/".format(_HF_ENDPOINT)
_URL = f"{_HF_ENDPOINT}/datasets/Anthropic/hh-rlhf/resolve/main/"
_URLS = {
    "train": [
        _URL + "harmless-base/train.jsonl.gz",
```

```
@@ -53,7 +53,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
    def _generate_examples(self, filepaths: List[str]):
        key = 0
        for filepath in filepaths:
            with open(filepath, "r", encoding="utf-8") as f:
            with open(filepath, encoding="utf-8") as f:
                for row in f:
                    data = json.loads(row)
                    chosen = data["chosen"]
```
```
@@ -20,9 +20,9 @@ _CITATION = """\
}
"""

_HOMEPAGE = "{}/datasets/stingning/ultrachat".format(_HF_ENDPOINT)
_HOMEPAGE = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat"
_LICENSE = "cc-by-nc-4.0"
_BASE_DATA_URL = "{}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl".format(_HF_ENDPOINT)
_BASE_DATA_URL = f"{_HF_ENDPOINT}/datasets/stingning/ultrachat/resolve/main/train_{{idx}}.jsonl"


class UltraChat(datasets.GeneratorBasedBuilder):
```

```
@@ -42,7 +42,7 @@ class UltraChat(datasets.GeneratorBasedBuilder):

    def _generate_examples(self, filepaths: List[str]):
        for filepath in filepaths:
            with open(filepath, "r", encoding="utf-8") as f:
            with open(filepath, encoding="utf-8") as f:
                for row in f:
                    try:
                        data = json.loads(row)
```
File diff suppressed because one or more lines are too long
```
@@ -1,6 +1,7 @@
# Use the NVIDIA official image with PyTorch 2.3.0
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-02.html
FROM nvcr.io/nvidia/pytorch:24.02-py3
# Default use the NVIDIA official image with PyTorch 2.3.0
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3
FROM ${BASE_IMAGE}

# Define environments
ENV MAX_JOBS=4
```

```
@@ -12,6 +13,9 @@ ARG INSTALL_BNB=false
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG INSTALL_FLASHATTN=false
ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG INSTALL_EETQ=false
ARG PIP_INDEX=https://pypi.org/simple

# Set the working directory
```

```
@@ -38,6 +42,15 @@ RUN EXTRA_PACKAGES="metrics"; \
    if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
    fi; \
    if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
    fi; \
    if [ "$INSTALL_HQQ" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
    fi; \
    if [ "$INSTALL_EETQ" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},eetq"; \
    fi; \
    pip install -e ".[$EXTRA_PACKAGES]"

# Rebuild flash attention
```
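The new `BASE_IMAGE` build argument lets the CUDA base image be swapped without editing the Dockerfile; a sketch of overriding it at build time (the alternative tag is only an example):

```bash
# assumption: run from the repository root; the tag below is illustrative
docker build -f ./docker/docker-cuda/Dockerfile \
    --build-arg BASE_IMAGE=nvcr.io/nvidia/pytorch:24.05-py3 \
    -t llamafactory:latest .
```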
```
@@ -8,11 +8,15 @@ services:
        INSTALL_VLLM: false
        INSTALL_DEEPSPEED: false
        INSTALL_FLASHATTN: false
        INSTALL_LIGER_KERNEL: false
        INSTALL_HQQ: false
        INSTALL_EETQ: false
        PIP_INDEX: https://pypi.org/simple
    container_name: llamafactory
    volumes:
      - ../../hf_cache:/root/.cache/huggingface
      - ../../ms_cache:/root/.cache/modelscope
      - ../../om_cache:/root/.cache/openmind
      - ../../data:/app/data
      - ../../output:/app/output
    ports:
```

```
@@ -20,6 +24,7 @@ services:
      - "8000:8000"
    ipc: host
    tty: true
    shm_size: '16gb'
    stdin_open: true
    command: bash
    deploy:
```
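A usage sketch for the compose file, assuming Docker Compose v2 and the `llamafactory` container name set above:

```bash
cd docker/docker-cuda
docker compose up -d           # build args and volume mounts come from the file above
docker exec -it llamafactory bash
```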
```
@@ -10,6 +10,7 @@ services:
    volumes:
      - ../../hf_cache:/root/.cache/huggingface
      - ../../ms_cache:/root/.cache/modelscope
      - ../../om_cache:/root/.cache/openmind
      - ../../data:/app/data
      - ../../output:/app/output
      - /usr/local/dcmi:/usr/local/dcmi
```

```
@@ -21,6 +22,7 @@ services:
      - "8000:8000"
    ipc: host
    tty: true
    shm_size: '16gb'
    stdin_open: true
    command: bash
    devices:
```
```
@@ -1,4 +1,4 @@
FROM hardandheavy/transformers-rocm:2.1.0
FROM hardandheavy/transformers-rocm:2.2.0

# Define environments
ENV MAX_JOBS=4
```

```
@@ -10,6 +10,8 @@ ARG INSTALL_BNB=false
ARG INSTALL_VLLM=false
ARG INSTALL_DEEPSPEED=false
ARG INSTALL_FLASHATTN=false
ARG INSTALL_LIGER_KERNEL=false
ARG INSTALL_HQQ=false
ARG PIP_INDEX=https://pypi.org/simple

# Set the working directory
```

```
@@ -36,6 +38,12 @@ RUN EXTRA_PACKAGES="metrics"; \
    if [ "$INSTALL_DEEPSPEED" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
    fi; \
    if [ "$INSTALL_LIGER_KERNEL" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},liger-kernel"; \
    fi; \
    if [ "$INSTALL_HQQ" == "true" ]; then \
        EXTRA_PACKAGES="${EXTRA_PACKAGES},hqq"; \
    fi; \
    pip install -e ".[$EXTRA_PACKAGES]"

# Rebuild flash attention
```
```
@@ -8,11 +8,14 @@ services:
        INSTALL_VLLM: false
        INSTALL_DEEPSPEED: false
        INSTALL_FLASHATTN: false
        INSTALL_LIGER_KERNEL: false
        INSTALL_HQQ: false
        PIP_INDEX: https://pypi.org/simple
    container_name: llamafactory
    volumes:
      - ../../hf_cache:/root/.cache/huggingface
      - ../../ms_cache:/root/.cache/modelscope
      - ../../om_cache:/root/.cache/openmind
      - ../../data:/app/data
      - ../../output:/app/output
      - ../../saves:/app/saves
```

```
@@ -21,6 +24,7 @@ services:
      - "8000:8000"
    ipc: host
    tty: true
    shm_size: '16gb'
    stdin_open: true
    command: bash
    devices:
```
```
@@ -158,5 +158,4 @@ class MMLU(datasets.GeneratorBasedBuilder):
        df = pd.read_csv(filepath, header=None)
        df.columns = ["question", "A", "B", "C", "D", "answer"]

        for i, instance in enumerate(df.to_dict(orient="records")):
            yield i, instance
        yield from enumerate(df.to_dict(orient="records"))
```
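The refactor replaces the explicit loop with `yield from`, which delegates directly to the iterator and yields the same `(index, record)` pairs; a minimal equivalence sketch:

```python
records = [{"question": "q1"}, {"question": "q2"}]

def explicit():
    for i, instance in enumerate(records):
        yield i, instance

def delegated():
    yield from enumerate(records)

# both generators yield (0, {...}), (1, {...})
assert list(explicit()) == list(delegated())
```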
@@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml

#### Supervised Fine-Tuning on Multiple Nodes

```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```
#### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
|
||||
|
||||
@@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
 #### Supervised Fine-Tuning on Multiple Nodes

 ```bash
-FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 ```

 #### Distributing VRAM Evenly with DeepSpeed ZeRO-3

@@ -10,7 +10,7 @@ use_adam_mini: true
 ### dataset
 dataset: identity,alpaca_en_demo
 template: qwen
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -15,7 +15,7 @@ badam_verbose: 2
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -11,7 +11,7 @@ lora_target: all
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -14,7 +14,7 @@ galore_scale: 2.0
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -12,7 +12,7 @@ use_llama_pro: true
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -11,7 +11,7 @@ loraplus_lr_ratio: 16.0
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -10,7 +10,7 @@ mixture_of_depths: convert
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -13,7 +13,7 @@ pissa_convert: true
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -9,7 +9,7 @@ finetuning_type: full
 ### dataset
 eval_dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -10,7 +10,7 @@ deepspeed: examples/deepspeed/ds_z3_config.json
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -10,7 +10,7 @@ deepspeed: examples/deepspeed/ds_z3_config.json
 ### dataset
 dataset: mllm_demo,identity
 template: qwen2_vl
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -12,7 +12,7 @@ pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
 ### dataset
 dataset: dpo_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -11,7 +11,7 @@ pref_beta: 0.1
 ### dataset
 dataset: kto_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -11,7 +11,7 @@ lora_target: all
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -10,7 +10,7 @@ finetuning_type: lora
 ### dataset
 eval_dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 50
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -9,7 +9,7 @@ lora_target: all

 ### dataset
 dataset: c4_demo
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -10,7 +10,7 @@ lora_target: all
 ### dataset
 dataset: dpo_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -10,7 +10,7 @@ lora_target: all
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -11,7 +11,7 @@ deepspeed: examples/deepspeed/ds_z0_config.json
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -11,7 +11,7 @@ deepspeed: examples/deepspeed/ds_z3_config.json
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -10,7 +10,7 @@ lora_target: all
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -10,7 +10,7 @@ lora_target: all
 ### dataset
 dataset: mllm_demo
 template: llava
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -12,7 +12,7 @@ pref_loss: sigmoid # choices: [sigmoid (dpo), orpo, simpo]
 ### dataset
 dataset: rlhf_v
 template: qwen2_vl
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -10,7 +10,7 @@ lora_target: all
 ### dataset
 dataset: mllm_demo,identity # video: mllm_video_demo
 template: qwen2_vl
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -10,7 +10,7 @@ lora_target: all
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -10,7 +10,7 @@ lora_target: all
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -10,7 +10,7 @@ lora_target: all
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

@@ -12,7 +12,7 @@ lora_target: all
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3
-cutoff_len: 1024
+cutoff_len: 2048
 max_samples: 1000
 overwrite_cache: true
 preprocessing_num_workers: 16

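Every example config above doubles `cutoff_len` from 1024 to 2048. `cutoff_len` caps the tokenized length of each training sample, so longer inputs now survive preprocessing intact. A hedged sketch of the effect, using the GPT-2 tokenizer as a stand-in for the template's actual tokenizer:

```python
from transformers import AutoTokenizer

cutoff_len = 2048  # value set by the updated configs
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer

ids = tokenizer.encode("a very long training sample ...")
ids = ids[:cutoff_len]  # tokens beyond cutoff_len are truncated away
print(len(ids))
```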
@@ -1,9 +1,9 @@
-transformers>=4.41.2,<=4.45.0
-datasets>=2.16.0,<=2.21.0
-accelerate>=0.30.1,<=0.33.0
+transformers>=4.41.2,<=4.46.1
+datasets>=2.16.0,<=3.1.0
+accelerate>=0.34.0,<=1.0.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6
-gradio>=4.0.0
+gradio>=4.0.0,<5.0.0
 pandas>=2.0.0
 scipy
 einops
@@ -19,3 +19,5 @@ fire
 packaging
 pyyaml
 numpy<2.0.0
+av
+tyro<0.9.0

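One way to sanity-check an environment against these pins is to compare installed versions to the specifier ranges; a small sketch using `importlib.metadata` and `packaging` (the `pins` dict is a hand-copied subset, not parsed from requirements.txt):

```python
from importlib.metadata import version

from packaging.specifiers import SpecifierSet

pins = {
    "transformers": ">=4.41.2,<=4.46.1",
    "datasets": ">=2.16.0,<=3.1.0",
    "accelerate": ">=0.34.0,<=1.0.1",
}

for name, spec in pins.items():
    installed = version(name)
    ok = installed in SpecifierSet(spec)  # True if the pin is satisfied
    print(f"{name}=={installed} satisfies '{spec}': {ok}")
```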
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 Microsoft Corporation and the LlamaFactory team.
 #
 # This code is inspired by the Microsoft's DeepSpeed library.

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 imoneoi and the LlamaFactory team.
 #
 # This code is inspired by the imoneoi's OpenChat library.
@@ -74,7 +73,7 @@ def calculate_lr(
     elif stage == "sft":
         data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
     else:
-        raise NotImplementedError("Stage does not supported: {}.".format(stage))
+        raise NotImplementedError(f"Stage does not supported: {stage}.")

     dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
     valid_tokens, total_tokens = 0, 0

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -100,7 +99,7 @@ def compute_device_flops(world_size: int) -> float:
     elif "4090" in device_name:
         return 98 * 1e12 * world_size
     else:
-        raise NotImplementedError("Device not supported: {}.".format(device_name))
+        raise NotImplementedError(f"Device not supported: {device_name}.")


 def calculate_mfu(
@@ -140,10 +139,10 @@ def calculate_mfu(
         "bf16": True,
     }
     if deepspeed_stage in [2, 3]:
-        args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
+        args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"

     run_exp(args)
-    with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
+    with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
         result = json.load(f)

     if dist.is_initialized():
@@ -157,7 +156,7 @@ def calculate_mfu(
         * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
         / compute_device_flops(world_size)
     )
-    print("MFU: {:.2f}%".format(mfu_value * 100))
+    print(f"MFU: {mfu_value * 100:.2f}%")


 if __name__ == "__main__":

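For context, MFU (model FLOPs utilization) is the ratio between the FLOPs the run actually performed and the hardware's peak FLOPs over the same wall time. A back-of-the-envelope sketch with illustrative numbers (the 6·params·tokens approximation for forward+backward FLOPs is a common rule of thumb, not taken from this script):

```python
params = 8e9           # model parameters (illustrative)
tokens = 2e6           # tokens processed in the measured window
train_seconds = 120.0  # wall-clock time for that window
peak_flops = 98e12     # peak FLOPs of one device, e.g. the 4090 figure above

model_flops = 6 * params * tokens  # ~6 FLOPs per parameter per token, fwd+bwd
mfu_value = model_flops / (peak_flops * train_seconds)
print(f"MFU: {mfu_value * 100:.2f}%")
```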
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -100,7 +99,7 @@ def calculate_ppl(
             tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
         )
     else:
-        raise NotImplementedError("Stage does not supported: {}.".format(stage))
+        raise NotImplementedError(f"Stage does not supported: {stage}.")

     dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
     criterion = torch.nn.CrossEntropyLoss(reduction="none")
@@ -125,8 +124,8 @@ def calculate_ppl(
     with open(save_name, "w", encoding="utf-8") as f:
         json.dump(perplexities, f, indent=2)

-    print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
-    print("Perplexities have been saved at {}.".format(save_name))
+    print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
+    print(f"Perplexities have been saved at {save_name}.")


 if __name__ == "__main__":

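Perplexity here is the exponentiated mean of the token-level cross-entropy losses. A self-contained sketch of that computation on dummy logits (not the script's full data pipeline):

```python
import torch

# Dummy batch: 1 sequence, 4 predicted positions over a 10-token vocabulary.
logits = torch.randn(1, 4, 10)
labels = torch.tensor([[1, 3, 5, 7]])

criterion = torch.nn.CrossEntropyLoss(reduction="none")
token_losses = criterion(logits.transpose(1, 2), labels)  # shape (1, 4)
ppl = torch.exp(token_losses.mean())  # perplexity = exp(mean NLL)
print(f"Average perplexity is {ppl.item():.2f}")
```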
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -61,7 +60,7 @@ def length_cdf(
     for length, count in length_tuples:
         count_accu += count
         prob_accu += count / total_num * 100
-        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
+        print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")


 if __name__ == "__main__":

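The loop prints a cumulative distribution over bucketed sample lengths. A toy reproduction of the bookkeeping, detached from the dataset loading:

```python
from collections import Counter

interval = 1000
lengths = [120, 850, 1100, 1900, 2400, 3100]  # toy sample lengths
buckets = Counter(length // interval * interval for length in lengths)

total_num = len(lengths)
count_accu, prob_accu = 0, 0.0
for length, count in sorted(buckets.items()):
    count_accu += count
    prob_accu += count / total_num * 100
    print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")
```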
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 Tencent Inc. and the LlamaFactory team.
 #
 # This code is inspired by the Tencent's LLaMA-Pro library.
@@ -40,7 +39,7 @@ if TYPE_CHECKING:


 def change_name(name: str, old_index: int, new_index: int) -> str:
-    return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))
+    return name.replace(f".{old_index:d}.", f".{new_index:d}.")


 def block_expansion(
@@ -76,27 +75,27 @@ def block_expansion(
     state_dict = model.state_dict()

     if num_layers % num_expand != 0:
-        raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))
+        raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")

     split = num_layers // num_expand
     layer_cnt = 0
     output_state_dict = OrderedDict()
     for i in range(num_layers):
         for key, value in state_dict.items():
-            if ".{:d}.".format(i) in key:
+            if f".{i:d}." in key:
                 output_state_dict[change_name(key, i, layer_cnt)] = value

-        print("Add layer {} copied from layer {}".format(layer_cnt, i))
+        print(f"Add layer {layer_cnt} copied from layer {i}")
         layer_cnt += 1
         if (i + 1) % split == 0:
             for key, value in state_dict.items():
-                if ".{:d}.".format(i) in key:
+                if f".{i:d}." in key:
                     if "down_proj" in key or "o_proj" in key:
                         output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
                     else:
                         output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)

-            print("Add layer {} expanded from layer {}".format(layer_cnt, i))
+            print(f"Add layer {layer_cnt} expanded from layer {i}")
             layer_cnt += 1

     for key, value in state_dict.items():
@@ -113,17 +112,17 @@ def block_expansion(
         torch.save(shard, os.path.join(output_dir, shard_file))

     if index is None:
-        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
+        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
     else:
         index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
         with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)
-        print("Model weights saved in {}".format(output_dir))
+        print(f"Model weights saved in {output_dir}")

     print("- Fine-tune this model with:")
-    print("model_name_or_path: {}".format(output_dir))
+    print(f"model_name_or_path: {output_dir}")
     print("finetuning_type: freeze")
-    print("freeze_trainable_layers: {}".format(num_expand))
+    print(f"freeze_trainable_layers: {num_expand}")
     print("use_llama_pro: true")

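`block_expansion` interleaves one zero-initialized copy after every `split` original layers and renumbers the state-dict keys as it goes. A toy sketch of just the index bookkeeping (no real weights involved):

```python
num_layers, num_expand = 8, 2
split = num_layers // num_expand  # one expanded layer per `split` originals

layer_cnt = 0
for i in range(num_layers):
    print(f"Add layer {layer_cnt} copied from layer {i}")
    layer_cnt += 1
    if (i + 1) % split == 0:
        # down_proj/o_proj are zero-initialized so the new block starts as a no-op
        print(f"Add layer {layer_cnt} expanded from layer {i}")
        layer_cnt += 1
```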
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -63,16 +62,16 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
         torch.save(shard, os.path.join(output_dir, shard_file))

     if index is None:
-        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
+        print(f"Model weights saved in {os.path.join(output_dir, WEIGHTS_NAME)}")
     else:
         index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
         with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)
-        print("Model weights saved in {}".format(output_dir))
+        print(f"Model weights saved in {output_dir}")


 def save_config(input_dir: str, output_dir: str):
-    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+    with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
         llama2_config_dict: Dict[str, Any] = json.load(f)

     llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
@@ -82,7 +81,7 @@ def save_config(input_dir: str, output_dir: str):

     with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
         json.dump(llama2_config_dict, f, indent=2)
-    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+    print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")


 def llamafy_baichuan2(

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -86,7 +85,7 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
         elif "lm_head" in key:
             llama2_state_dict[key] = value
         else:
-            raise KeyError("Unable to process key {}".format(key))
+            raise KeyError(f"Unable to process key {key}")

     weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
     shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
@@ -98,18 +97,18 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
         torch.save(shard, os.path.join(output_dir, shard_file))

     if index is None:
-        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
+        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
     else:
         index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
         with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)
-        print("Model weights saved in {}".format(output_dir))
+        print(f"Model weights saved in {output_dir}")

     return str(torch_dtype).replace("torch.", "")


 def save_config(input_dir: str, output_dir: str, torch_dtype: str):
-    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
+    with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
         qwen_config_dict: Dict[str, Any] = json.load(f)

     llama2_config_dict: Dict[str, Any] = OrderedDict()
@@ -135,7 +134,7 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):

     with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
         json.dump(llama2_config_dict, f, indent=2)
-    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
+    print(f"Model config saved in {os.path.join(output_dir, CONFIG_NAME)}")


 def llamafy_qwen(

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is based on the HuggingFace's PEFT library.
@@ -70,19 +69,19 @@ def quantize_loftq(
     setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
     setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply loftq again
     peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
-    print("Adapter weights saved in {}".format(loftq_dir))
+    print(f"Adapter weights saved in {loftq_dir}")

     # Save base model
     base_model: "PreTrainedModel" = peft_model.unload()
     base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
     tokenizer.save_pretrained(output_dir)
-    print("Model weights saved in {}".format(output_dir))
+    print(f"Model weights saved in {output_dir}")

     print("- Fine-tune this model with:")
-    print("model_name_or_path: {}".format(output_dir))
-    print("adapter_name_or_path: {}".format(loftq_dir))
+    print(f"model_name_or_path: {output_dir}")
+    print(f"adapter_name_or_path: {loftq_dir}")
     print("finetuning_type: lora")
-    print("quantization_bit: {}".format(loftq_bits))
+    print(f"quantization_bit: {loftq_bits}")


 if __name__ == "__main__":

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is based on the HuggingFace's PEFT library.
@@ -54,7 +53,7 @@ def quantize_pissa(
         lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
         lora_dropout=lora_dropout,
         target_modules=lora_target,
-        init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
+        init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
     )

     # Init PiSSA model
@@ -65,17 +64,17 @@ def quantize_pissa(
     setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
     setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply pissa again
     peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
-    print("Adapter weights saved in {}".format(pissa_dir))
+    print(f"Adapter weights saved in {pissa_dir}")

     # Save base model
     base_model: "PreTrainedModel" = peft_model.unload()
     base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
     tokenizer.save_pretrained(output_dir)
-    print("Model weights saved in {}".format(output_dir))
+    print(f"Model weights saved in {output_dir}")

     print("- Fine-tune this model with:")
-    print("model_name_or_path: {}".format(output_dir))
-    print("adapter_name_or_path: {}".format(pissa_dir))
+    print(f"model_name_or_path: {output_dir}")
+    print(f"adapter_name_or_path: {pissa_dir}")
     print("finetuning_type: lora")
     print("pissa_init: false")
     print("pissa_convert: true")

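The `init_lora_weights` string is how PEFT selects PiSSA initialization, with `pissa_niter_N` requesting the fast-SVD variant with N iterations. A hedged usage sketch (stand-in model; requires a `peft` release that supports PiSSA):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # stand-in model
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
    init_lora_weights="pissa_niter_16",  # PiSSA init with 16 FSVD iterations
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
```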
scripts/test_image.py (new file)
@@ -0,0 +1,65 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from openai import OpenAI
from transformers.utils.versions import require_version


require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")


def main():
    client = OpenAI(
        api_key="{}".format(os.environ.get("API_KEY", "0")),
        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
    )
    messages = []
    messages.append(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Output the color and number of each box."},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/boxes.png"},
                },
            ],
        }
    )
    result = client.chat.completions.create(messages=messages, model="test")
    messages.append(result.choices[0].message)
    print("Round 1:", result.choices[0].message.content)
    # The image shows a pyramid of colored blocks with numbers on them. Here are the colors and numbers of ...
    messages.append(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What kind of flower is this?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-VL/flowers.jpg"},
                },
            ],
        }
    )
    result = client.chat.completions.create(messages=messages, model="test")
    messages.append(result.choices[0].message)
    print("Round 2:", result.choices[0].message.content)
    # The image shows a cluster of forget-me-not flowers. Forget-me-nots are small ...


if __name__ == "__main__":
    main()

@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 the LlamaFactory team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

setup.py
@@ -20,7 +20,7 @@ from setuptools import find_packages, setup


 def get_version() -> str:
-    with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
+    with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
         file_content = f.read()
     pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
     (version,) = re.findall(pattern, file_content)
@@ -28,7 +28,7 @@ def get_version() -> str:


 def get_requires() -> List[str]:
-    with open("requirements.txt", "r", encoding="utf-8") as f:
+    with open("requirements.txt", encoding="utf-8") as f:
         file_content = f.read()
     lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
     return lines
@@ -54,13 +54,14 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.6.0"],
+    "vllm": ["vllm>=0.4.3,<0.6.4"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
     "adam-mini": ["adam-mini"],
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
-    "dev": ["ruff", "pytest"],
+    "openmind": ["openmind"],
+    "dev": ["pre-commit", "ruff", "pytest"],
 }


@@ -71,7 +72,7 @@ def main():
     author="hiyouga",
     author_email="hiyouga" "@" "buaa.edu.cn",
     description="Easy-to-use LLM fine-tuning framework",
-    long_description=open("README.md", "r", encoding="utf-8").read(),
+    long_description=open("README.md", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
     keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
     license="Apache 2.0 License",

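`get_version` scrapes the version constant out of `env.py` with a regular expression instead of importing the package. A standalone sketch of the same pattern (the version string is an illustrative value):

```python
import re

# Stand-in for the contents of src/llamafactory/extras/env.py
file_content = 'VERSION = "0.9.1"\n'

pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content)
print(version)  # -> 0.9.1
```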
@@ -23,9 +23,9 @@ from llamafactory.chat import ChatModel
 def main():
     chat_model = ChatModel()
     app = create_app(chat_model)
-    api_host = os.environ.get("API_HOST", "0.0.0.0")
-    api_port = int(os.environ.get("API_PORT", "8000"))
-    print("Visit http://localhost:{}/docs for API document.".format(api_port))
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
+    print(f"Visit http://localhost:{api_port}/docs for API document.")
     uvicorn.run(app, host=api_host, port=api_port)

@@ -20,17 +20,17 @@ Level:

 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.45.0
-    datasets>=2.16.0,<=2.21.0
-    accelerate>=0.30.1,<=0.33.0
+    transformers>=4.41.2,<=4.46.1
+    datasets>=2.16.0,<=3.1.0
+    accelerate>=0.34.0,<=1.0.1
     peft>=0.11.1,<=0.12.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.45.0
+    transformers>=4.41.2,<=4.46.1
   packing:
-    transformers>=4.41.2,<=4.45.0
+    transformers>=4.41.2,<=4.46.1

 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
@@ -38,6 +38,7 @@ Force check imports: FORCE_CHECK_IMPORTS=1
 Force using torchrun: FORCE_TORCHRUN=1
 Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
 Use modelscope: USE_MODELSCOPE_HUB=1
+Use openmind: USE_OPENMIND_HUB=1
 """

 from .extras.env import VERSION

@@ -68,7 +68,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU mem


 def create_app(chat_model: "ChatModel") -> "FastAPI":
-    root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
+    root_path = os.getenv("FASTAPI_ROOT_PATH", "")
     app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
     app.add_middleware(
         CORSMiddleware,
@@ -77,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    api_key = os.environ.get("API_KEY", None)
+    api_key = os.getenv("API_KEY")
     security = HTTPBearer(auto_error=False)

     async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@@ -91,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         dependencies=[Depends(verify_api_key)],
     )
     async def list_models():
-        model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
+        model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
         return ModelList(data=[model_card])

     @app.post(
@@ -128,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 def run_api() -> None:
     chat_model = ChatModel()
     app = create_app(chat_model)
-    api_host = os.environ.get("API_HOST", "0.0.0.0")
-    api_port = int(os.environ.get("API_PORT", "8000"))
-    print("Visit http://localhost:{}/docs for API document.".format(api_port))
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
+    print(f"Visit http://localhost:{api_port}/docs for API document.")
     uvicorn.run(app, host=api_host, port=api_port)

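The API-key check above is a standard FastAPI bearer-token dependency. A trimmed, self-contained sketch of the pattern (endpoint and response body are illustrative):

```python
import os
from typing import Optional

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

app = FastAPI()
api_key = os.getenv("API_KEY")  # unset means the check is disabled
security = HTTPBearer(auto_error=False)


async def verify_api_key(auth: Optional[HTTPAuthorizationCredentials] = Depends(security)):
    # Reject only when a key is configured and the bearer token does not match.
    if api_key and (auth is None or auth.credentials != api_key):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid key")


@app.get("/v1/models", dependencies=[Depends(verify_api_key)])
async def list_models():
    return {"data": []}
```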
@@ -21,7 +21,7 @@ import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple

 from ..data import Role as DataRole
-from ..extras.logging import get_logger
+from ..extras import logging
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (
@@ -57,7 +57,7 @@ if TYPE_CHECKING:
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 ROLE_MAPPING = {
     Role.USER: DataRole.USER.value,
     Role.ASSISTANT: DataRole.ASSISTANT.value,
@@ -69,8 +69,8 @@ ROLE_MAPPING = {

 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
-    logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
+    logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")

     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@@ -84,7 +84,7 @@ def _process_request(
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

     input_messages = []
-    image = None
+    images = []
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -111,7 +111,7 @@ def _process_request(
             else:  # web uri
                 image_stream = requests.get(image_url, stream=True).raw

-            image = Image.open(image_stream).convert("RGB")
+            images.append(Image.open(image_stream).convert("RGB"))
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

@@ -124,7 +124,7 @@ def _process_request(
     else:
         tools = None

-    return input_messages, system, tools, image
+    return input_messages, system, tools, images or None


 def _create_stream_chat_completion_chunk(
@@ -142,13 +142,13 @@ def _create_stream_chat_completion_chunk(
 async def create_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
-    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools, image = _process_request(request)
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+    input_messages, system, tools, images = _process_request(request)
     responses = await chat_model.achat(
         input_messages,
         system,
         tools,
-        image,
+        images,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
@@ -169,7 +169,7 @@ async def create_chat_completion_response(
     tool_calls = []
     for tool in result:
         function = Function(name=tool[0], arguments=tool[1])
-        tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+        tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))

     response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
     finish_reason = Finish.TOOL
@@ -193,8 +193,8 @@ async def create_chat_completion_response(
 async def create_stream_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
-    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools, image = _process_request(request)
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+    input_messages, system, tools, images = _process_request(request)
     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")

@@ -208,7 +208,7 @@ async def create_stream_chat_completion_response(
         input_messages,
         system,
         tools,
-        image,
+        images,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
@@ -229,8 +229,9 @@ async def create_stream_chat_completion_response(
 async def create_score_evaluation_response(
     request: "ScoreEvaluationRequest", chat_model: "ChatModel"
 ) -> "ScoreEvaluationResponse":
+    score_id = f"scoreval-{uuid.uuid4().hex}"
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

     scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
-    return ScoreEvaluationResponse(model=request.model, scores=scores)
+    return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)

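With `_process_request` returning a list, a single OpenAI-style request can now carry several `image_url` parts. A sketch of such a payload against a locally served model (URLs are illustrative):

```python
from openai import OpenAI

client = OpenAI(api_key="0", base_url="http://localhost:8000/v1")
result = client.chat.completions.create(
    model="test",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Compare these two images."},
                {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}},
                {"type": "image_url", "image_url": {"url": "https://example.com/b.png"}},
            ],
        }
    ],
)
print(result.choices[0].message.content)
```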
@@ -66,8 +66,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
@@ -81,8 +81,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""

@@ -53,7 +53,7 @@ class ChatModel:
         elif model_args.infer_backend == "vllm":
             self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
-            raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
+            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")

         self._loop = asyncio.new_event_loop()
         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
@@ -64,15 +64,15 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Gets a list of responses of the chat model.
         """
         task = asyncio.run_coroutine_threadsafe(
-            self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
+            self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
         )
         return task.result()

@@ -81,28 +81,28 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Asynchronously gets a list of responses of the chat model.
         """
-        return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
+        return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)

     def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
         r"""
         Gets the response token-by-token of the chat model.
         """
-        generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
+        generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
         while True:
             try:
                 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -115,14 +115,14 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
         Asynchronously gets the response token-by-token of the chat model.
         """
-        async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
+        async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
             yield new_token

     def get_scores(

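`ChatModel` keeps a private event loop on a daemon thread and bridges its synchronous methods onto it with `asyncio.run_coroutine_threadsafe`. A minimal standalone sketch of that pattern:

```python
import asyncio
from threading import Thread


async def achat(prompt: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for asynchronous generation
    return f"echo: {prompt}"


loop = asyncio.new_event_loop()
Thread(target=loop.run_forever, daemon=True).start()

# Schedule the coroutine on the background loop, then block for the result.
task = asyncio.run_coroutine_threadsafe(achat("hello"), loop)
print(task.result())
```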
@@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer
 from typing_extensions import override

 from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
 from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
-from ..extras.logging import get_logger
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class HuggingfaceEngine(BaseEngine):
@@ -63,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
         try:
             asyncio.get_event_loop()
         except RuntimeError:
-            logger.warning("There is no current event loop, creating a new one.")
+            logger.warning_once("There is no current event loop, creating a new one.")
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)

-        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
+        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

     @staticmethod
     def _process_args(
@@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
         mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
-        if image is not None:
-            mm_input_dict.update({"images": [image], "imglens": [1]})
-            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
+        if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
+            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

-        if video is not None:
-            mm_input_dict.update({"videos": [video], "vidlens": [1]})
-            if VIDEO_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

         messages = template.mm_plugin.process_messages(
             messages, mm_input_dict["images"], mm_input_dict["videos"], processor
@@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

         if stop is not None:
-            logger.warning("Stop parameter is not supported by the huggingface engine yet.")
+            logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")

         generating_args = generating_args.copy()
         generating_args.update(
@@ -164,9 +164,13 @@ class HuggingfaceEngine(BaseEngine):
             logits_processor=get_logits_processor(),
         )

-        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
+        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
         for key, value in mm_inputs.items():
-            value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+            if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
+                value = torch.stack(value)  # assume they have same sizes
+            elif not isinstance(value, torch.Tensor):
+                value = torch.tensor(value)
+
             gen_kwargs[key] = value.to(model.device)

         return gen_kwargs, prompt_length
@@ -182,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model,
+            tokenizer,
+            processor,
+            template,
+            generating_args,
+            messages,
+            system,
+            tools,
+            images,
+            videos,
+            input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
@@ -218,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model,
+            tokenizer,
+            processor,
+            template,
+            generating_args,
+            messages,
+            system,
+            tools,
+            images,
+            videos,
+            input_kwargs,
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
@@ -246,29 +270,18 @@ class HuggingfaceEngine(BaseEngine):
         batch_input: List[str],
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List[float]:
-        max_length = input_kwargs.pop("max_length", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
         device = getattr(model.pretrained_model, "device", "cuda")
-        inputs = tokenizer(
+        inputs: Dict[str, "torch.Tensor"] = tokenizer(
             batch_input,
             padding=True,
             truncation=True,
             max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
             return_tensors="pt",
-            add_special_tokens=True,
+            add_special_tokens=False,
         ).to(device)

-        input_ids: torch.Tensor = inputs["input_ids"]
-        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
-
-        if getattr(model.config, "model_type", None) == "chatglm":
-            values = torch.transpose(values, 0, 1)
-
-        scores = []
-        for i in range(input_ids.size(0)):
-            end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
-            end_index = end_indexes[-1].item() if len(end_indexes) else 0
-            scores.append(values[i, end_index].nan_to_num().item())
-
+        values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
+        scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return scores

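The rewritten `get_scores` replaces the per-row Python loop with a single `gather`: for right-padded batches, the index of the last real token in each row is `attention_mask.sum(-1) - 1`. A toy sketch of that indexing (right-side padding is an assumption of the trick):

```python
import torch

# Per-token values for 2 right-padded sequences of max length 4.
values = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                       [1.0, 2.0, 0.0, 0.0]])
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

last_idx = attention_mask.sum(dim=-1, keepdim=True) - 1  # [[3], [1]]
scores = values.gather(dim=-1, index=last_idx)           # [[0.4], [2.0]]
print(scores.squeeze(-1))
```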
     @override
@@ -277,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:
@@ -294,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
             messages,
             system,
             tools,
-            image,
-            video,
+            images,
+            videos,
             input_kwargs,
         )
         async with self.semaphore:
@@ -308,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:
@@ -325,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
             messages,
             system,
             tools,
-            image,
-            video,
+            images,
+            videos,
             input_kwargs,
         )
         async with self.semaphore:

@@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
|
||||
from typing_extensions import override
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_device_count
|
||||
from ..extras.packages import is_pillow_available, is_vllm_available
|
||||
from ..model import load_config, load_tokenizer
|
||||
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VllmEngine(BaseEngine):
|
||||
@@ -83,11 +83,13 @@ class VllmEngine(BaseEngine):
|
||||
"enable_lora": model_args.adapter_name_or_path is not None,
|
||||
"max_lora_rank": model_args.vllm_max_lora_rank,
|
||||
}
|
||||
if isinstance(model_args.vllm_config, dict):
|
||||
engine_args.update(model_args.vllm_config)
|
||||
|
||||
if getattr(config, "is_yi_vl_derived_model", None):
|
||||
import vllm.model_executor.models.llava
|
||||
|
||||
logger.info("Detected Yi-VL model, applying projector patch.")
|
||||
logger.info_rank0("Detected Yi-VL model, applying projector patch.")
|
||||
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
|
||||
|
||||
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
|
||||
@@ -101,21 +103,28 @@ class VllmEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncIterator["RequestOutput"]:
|
||||
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
if image is not None:
|
||||
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
||||
request_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
if images is not None:
|
||||
if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
|
||||
|
||||
paired_messages = messages + [{"role": "assistant", "content": ""}]
|
||||
if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin": # temporary solution
|
||||
image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
|
||||
else:
|
||||
image_str = self.template.mm_plugin.image_token or ""
|
||||
|
||||
paired_messages = [
|
||||
{"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
|
||||
for message in messages
|
||||
] + [{"role": "assistant", "content": ""}]
|
||||
system = system or self.generating_args["default_system"]
|
||||
prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
|
||||
prompt_length = len(prompt_ids)
|
||||
|
||||
use_beam_search: bool = self.generating_args["num_beams"] > 1
|
||||
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
||||
top_p: Optional[float] = input_kwargs.pop("top_p", None)
|
||||
top_k: Optional[float] = input_kwargs.pop("top_k", None)
|
||||
@@ -126,6 +135,9 @@ class VllmEngine(BaseEngine):
|
||||
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
|
||||
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
|
||||
|
||||
if length_penalty is not None:
|
||||
logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
|
||||
|
||||
if "max_new_tokens" in self.generating_args:
|
||||
max_tokens = self.generating_args["max_new_tokens"]
|
||||
elif "max_length" in self.generating_args:
|
||||
@@ -149,27 +161,29 @@ class VllmEngine(BaseEngine):
|
||||
temperature=temperature if temperature is not None else self.generating_args["temperature"],
|
||||
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
|
||||
top_k=top_k if top_k is not None else self.generating_args["top_k"],
|
||||
use_beam_search=use_beam_search,
|
||||
length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
|
||||
stop=stop,
|
||||
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
max_tokens=max_tokens,
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
if image is not None: # add image features
|
||||
if not isinstance(image, (str, ImageObject)):
|
||||
raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
|
||||
if images is not None: # add image features
|
||||
image_data = []
|
||||
for image in images:
|
||||
if not isinstance(image, (str, ImageObject)):
|
||||
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
|
||||
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image).convert("RGB")
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image).convert("RGB")
|
||||
|
||||
multi_modal_data = {"image": image}
|
||||
image_data.append(image)
|
||||
|
||||
multi_modal_data = {"image": image_data}
|
||||
else:
|
||||
multi_modal_data = None
|
||||
|
||||
result_generator = self.model.generate(
|
||||
inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
||||
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
lora_request=self.lora_request,
|
||||
@@ -182,12 +196,12 @@ class VllmEngine(BaseEngine):
     messages: Sequence[Dict[str, str]],
     system: Optional[str] = None,
     tools: Optional[str] = None,
-    image: Optional["ImageInput"] = None,
-    video: Optional["VideoInput"] = None,
+    images: Optional[Sequence["ImageInput"]] = None,
+    videos: Optional[Sequence["VideoInput"]] = None,
     **input_kwargs,
 ) -> List["Response"]:
     final_output = None
-    generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
+    generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
     async for request_output in generator:
         final_output = request_output

@@ -210,12 +224,12 @@ class VllmEngine(BaseEngine):
     messages: Sequence[Dict[str, str]],
     system: Optional[str] = None,
     tools: Optional[str] = None,
-    image: Optional["ImageInput"] = None,
-    video: Optional["VideoInput"] = None,
+    images: Optional[Sequence["ImageInput"]] = None,
+    videos: Optional[Sequence["VideoInput"]] = None,
     **input_kwargs,
 ) -> AsyncGenerator[str, None]:
     generated_text = ""
-    generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
+    generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
     async for result in generator:
         delta_text = result.outputs[0].text[len(generated_text) :]
         generated_text = result.outputs[0].text
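The public chat methods now take `images`/`videos` sequences instead of a single `image`/`video`. A hedged usage sketch, assuming the repository's `ChatModel` wrapper forwards these arguments unchanged (model name and file name are illustrative):

```python
import asyncio

from llamafactory.chat import ChatModel

async def main():
    chat_model = ChatModel({"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"})
    messages = [{"role": "user", "content": "<image>What is in the picture?"}]
    responses = await chat_model.achat(messages, images=["picture.png"])  # a list now, not a single image
    print(responses[0].response_text)

asyncio.run(main())
```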
@@ -22,8 +22,8 @@ from . import launcher
 from .api.app import run_api
 from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
+from .extras import logging
 from .extras.env import VERSION, print_env
-from .extras.logging import get_logger
 from .extras.misc import get_device_count
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
@@ -47,7 +47,7 @@ USAGE = (
 WELCOME = (
     "-" * 58
     + "\n"
-    + "| Welcome to LLaMA Factory, version {}".format(VERSION)
+    + f"| Welcome to LLaMA Factory, version {VERSION}"
     + " " * (21 - len(VERSION))
     + "|\n|"
     + " " * 56
@@ -56,7 +56,7 @@ WELCOME = (
     + "-" * 58
 )

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 @unique
@@ -86,25 +86,26 @@ def main():
     elif command == Command.EXPORT:
         export_model()
     elif command == Command.TRAIN:
-        force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
+        force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
         if force_torchrun or get_device_count() > 1:
-            master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
-            master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
-            logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
+            master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
+            master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
+            logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
             process = subprocess.run(
                 (
                     "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
                     "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
-                ).format(
-                    nnodes=os.environ.get("NNODES", "1"),
-                    node_rank=os.environ.get("RANK", "0"),
-                    nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
+                )
+                .format(
+                    nnodes=os.getenv("NNODES", "1"),
+                    node_rank=os.getenv("NODE_RANK", "0"),
+                    nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
                     master_addr=master_addr,
                     master_port=master_port,
                     file_name=launcher.__file__,
                     args=" ".join(sys.argv[1:]),
-                ),
-                shell=True,
+                )
+                .split()
             )
             sys.exit(process.returncode)
         else:
@@ -118,4 +119,4 @@ def main():
     elif command == Command.HELP:
         print(USAGE)
     else:
-        raise NotImplementedError("Unknown command: {}.".format(command))
+        raise NotImplementedError(f"Unknown command: {command}.")
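The launcher now reads configuration with `os.getenv`, uses `NODE_RANK` rather than `RANK`, and splits the command string into an argv list so `subprocess.run` no longer needs `shell=True`. A small sketch of that construction with made-up environment values:

```python
import os

os.environ.setdefault("NPROC_PER_NODE", "2")  # pretend this node has 2 GPUs
command = "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node}".format(
    nnodes=os.getenv("NNODES", "1"),
    node_rank=os.getenv("NODE_RANK", "0"),
    nproc_per_node=os.getenv("NPROC_PER_NODE", "1"),
)
print(command.split())  # ['torchrun', '--nnodes', '1', '--node_rank', '0', '--nproc_per_node', '2']
```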
@@ -16,7 +16,7 @@ import os
 from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

-from ..extras.logging import get_logger
+from ..extras import logging
 from .data_utils import Role


@@ -29,45 +29,51 @@ if TYPE_CHECKING:
     from .parser import DatasetAttr


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _convert_images(
-    images: Sequence["ImageInput"],
+    images: Union["ImageInput", Sequence["ImageInput"]],
     dataset_attr: "DatasetAttr",
     data_args: "DataArguments",
 ) -> Optional[List["ImageInput"]]:
     r"""
     Optionally concatenates image path to dataset dir when loading from local disk.
     """
-    if len(images) == 0:
+    if not isinstance(images, list):
+        images = [images]
+    elif len(images) == 0:
         return None
+    else:
+        images = images[:]

-    images = images[:]
     if dataset_attr.load_from in ["script", "file"]:
         for i in range(len(images)):
-            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
-                images[i] = os.path.join(data_args.dataset_dir, images[i])
+            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
+                images[i] = os.path.join(data_args.image_dir, images[i])

     return images


 def _convert_videos(
-    videos: Sequence["VideoInput"],
+    videos: Union["VideoInput", Sequence["VideoInput"]],
     dataset_attr: "DatasetAttr",
     data_args: "DataArguments",
 ) -> Optional[List["VideoInput"]]:
     r"""
     Optionally concatenates video path to dataset dir when loading from local disk.
     """
-    if len(videos) == 0:
+    if not isinstance(videos, list):
+        videos = [videos]
+    elif len(videos) == 0:
         return None
+    else:
+        videos = videos[:]

-    videos = videos[:]
     if dataset_attr.load_from in ["script", "file"]:
         for i in range(len(videos)):
-            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
-                videos[i] = os.path.join(data_args.dataset_dir, videos[i])
+            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
+                videos[i] = os.path.join(data_args.image_dir, videos[i])

     return videos
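These converters now also accept a single (non-list) image or video entry and join relative paths against `data_args.image_dir`. A simplified stand-in showing the accepted shapes (function name here is illustrative, not the repository's):

```python
import os

def convert_media(entries, media_dir):  # simplified stand-in for _convert_images/_convert_videos
    if not isinstance(entries, list):
        entries = [entries]  # a bare entry is wrapped into a list
    elif len(entries) == 0:
        return None
    return [
        os.path.join(media_dir, e) if isinstance(e, str) and os.path.isfile(os.path.join(media_dir, e)) else e
        for e in entries
    ]

print(convert_media("demo.jpg", "data"))  # ['demo.jpg'], or ['data/demo.jpg'] if that file exists
print(convert_media([], "data"))          # None
```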
@@ -161,7 +167,7 @@ def convert_sharegpt(
     broken_data = False
     for turn_idx, message in enumerate(messages):
         if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
-            logger.warning("Invalid role tag in {}.".format(messages))
+            logger.warning_rank0(f"Invalid role tag in {messages}.")
             broken_data = True

         aligned_messages.append(
@@ -171,7 +177,7 @@ def convert_sharegpt(
     if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
         dataset_attr.ranking and len(aligned_messages) % 2 == 0
     ):
-        logger.warning("Invalid message count in {}.".format(messages))
+        logger.warning_rank0(f"Invalid message count in {messages}.")
         broken_data = True

     if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
@@ -192,7 +198,7 @@ def convert_sharegpt(
         chosen[dataset_attr.role_tag] not in accept_tags[-1]
         or rejected[dataset_attr.role_tag] not in accept_tags[-1]
     ):
-        logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
+        logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
         broken_data = True

     prompt = aligned_messages
@@ -205,7 +211,7 @@ def convert_sharegpt(
     response = aligned_messages[-1:]

     if broken_data:
-        logger.warning("Skipping this abnormal example.")
+        logger.warning_rank0("Skipping this abnormal example.")
         prompt, response = [], []

     convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
@@ -79,7 +79,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     processor: Optional["ProcessorMixin"] = None

     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
+        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
         for feature in features:
             images = feature.pop("images", None) or []
             videos = feature.pop("videos", None) or []
@@ -87,10 +87,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             batch_videos.extend(videos)
             batch_imglens.append(len(images))
             batch_vidlens.append(len(videos))
-            batch_seqlens.append(len(feature["input_ids"]))
+            batch_input_ids.append(feature["input_ids"])

         mm_inputs = self.template.mm_plugin.get_mm_inputs(
-            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
+            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
         )
         if "token_type_ids" in mm_inputs:
             token_type_ids = mm_inputs.pop("token_type_ids")
@@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

         features: Dict[str, "torch.Tensor"] = super().__call__(features)
         features.update(mm_inputs)
+        if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
+            features = features.data  # use default_collate() instead of BatchEncoding.to()

         return features


@@ -137,9 +140,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     for key in ("chosen", "rejected"):
         for feature in features:
             target_feature = {
-                "input_ids": feature["{}_input_ids".format(key)],
-                "attention_mask": feature["{}_attention_mask".format(key)],
-                "labels": feature["{}_labels".format(key)],
+                "input_ids": feature[f"{key}_input_ids"],
+                "attention_mask": feature[f"{key}_attention_mask"],
+                "labels": feature[f"{key}_labels"],
                 "images": feature["images"],
                 "videos": feature["videos"],
             }
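Why the collator now passes full `input_ids` instead of lengths: some plugins (e.g. mllama's cross-attention mask below) need the position of each image token, not just the sequence length. The old information stays recoverable; a hedged illustration with made-up token ids:

```python
batch_input_ids = [[1, 32000, 5, 6], [1, 7, 32000, 8, 9]]  # made-up token ids
batch_seqlens = [len(ids) for ids in batch_input_ids]       # [4, 5], what the collator used to pass
image_token_id = 32000                                       # hypothetical image token id
positions = [[i for i, t in enumerate(ids) if t == image_token_id] for ids in batch_input_ids]
print(batch_seqlens, positions)  # the extra information the plugins can now use
```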
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict

 from datasets import DatasetDict, concatenate_datasets, interleave_datasets

-from ..extras.logging import get_logger
+from ..extras import logging


 if TYPE_CHECKING:
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@@ -56,12 +56,12 @@ def merge_dataset(
         return all_datasets[0]
     elif data_args.mix_strategy == "concat":
         if data_args.streaming:
-            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
+            logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")

         return concatenate_datasets(all_datasets)
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
-            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
+            logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")

         return interleave_datasets(
             datasets=all_datasets,
@@ -70,7 +70,7 @@ def merge_dataset(
             stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
         )
     else:
-        raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
+        raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")


 def split_dataset(
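For reference, the two mixing strategies dispatched above behave like this on toy datasets (using the `datasets` library directly):

```python
from datasets import Dataset, concatenate_datasets, interleave_datasets

ds_a = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
ds_b = Dataset.from_dict({"text": ["b1"]})

print(concatenate_datasets([ds_a, ds_b])["text"])  # concat: ['a1', 'a2', 'a3', 'b1']

# interleave_*_under maps to stopping_strategy="first_exhausted": samples
# alternate across datasets and stop once the smallest one runs out.
print(interleave_datasets([ds_a, ds_b], stopping_strategy="first_exhausted")["text"])  # ['a1', 'b1']
```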
@@ -83,14 +83,14 @@ class StringFormatter(Formatter):
     if isinstance(slot, str):
         for name, value in kwargs.items():
             if not isinstance(value, str):
-                raise RuntimeError("Expected a string, got {}".format(value))
+                raise RuntimeError(f"Expected a string, got {value}")

             slot = slot.replace("{{" + name + "}}", value, 1)
         elements.append(slot)
     elif isinstance(slot, (dict, set)):
         elements.append(slot)
     else:
-        raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+        raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")

     return elements

@@ -113,7 +113,7 @@ class FunctionFormatter(Formatter):
         functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))

     except json.JSONDecodeError:
-        functions = []
+        raise RuntimeError(f"Invalid JSON format in function message: {str([content])}")  # flat string

     elements = []
     for name, arguments in functions:
@@ -124,7 +124,7 @@ class FunctionFormatter(Formatter):
     elif isinstance(slot, (dict, set)):
         elements.append(slot)
     else:
-        raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+        raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")

     return elements

@@ -141,7 +141,7 @@ class ToolFormatter(Formatter):
         tools = json.loads(content)
         return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
     except json.JSONDecodeError:
-        return [""]
+        raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}")  # flat string

 @override
 def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
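The formatters now fail loudly on malformed JSON instead of silently emitting empty output. For context, a minimal sketch of the `{{name}}` slot substitution that `StringFormatter` performs, under the same rules as above (values must be strings, one replacement per slot):

```python
def apply_slots(slots, **kwargs):  # simplified stand-in for StringFormatter.apply
    elements = []
    for slot in slots:
        if isinstance(slot, str):
            for name, value in kwargs.items():
                if not isinstance(value, str):
                    raise RuntimeError(f"Expected a string, got {value}")
                slot = slot.replace("{{" + name + "}}", value, 1)
            elements.append(slot)
        else:
            elements.append(slot)  # dict/set slots pass through untouched
    return elements

print(apply_slots(["User: {{content}}\n"], content="hello"))  # ['User: hello\n']
```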
@@ -20,8 +20,8 @@ import numpy as np
 from datasets import DatasetDict, load_dataset, load_from_disk
 from transformers.utils.versions import require_version

+from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
-from ..extras.logging import get_logger
 from ..extras.misc import has_tokenized_data
 from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from .template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _load_single_dataset(
@@ -51,9 +51,9 @@ def _load_single_dataset(
     r"""
     Loads a single dataset and aligns it to the standard format.
     """
-    logger.info("Loading dataset {}...".format(dataset_attr))
+    logger.info_rank0(f"Loading dataset {dataset_attr}...")
     data_path, data_name, data_dir, data_files = None, None, None, None
-    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
+    if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
         data_path = dataset_attr.dataset_name
         data_name = dataset_attr.subset
         data_dir = dataset_attr.folder
@@ -69,25 +69,24 @@ def _load_single_dataset(
         if os.path.isdir(local_path):  # is directory
             for file_name in os.listdir(local_path):
                 data_files.append(os.path.join(local_path, file_name))
-                if data_path is None:
-                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
-                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
-                    raise ValueError("File types should be identical.")
         elif os.path.isfile(local_path):  # is file
             data_files.append(local_path)
-            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
         else:
-            raise ValueError("File {} not found.".format(local_path))
+            raise ValueError(f"File {local_path} not found.")

+        data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
         if data_path is None:
             raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
+
+        if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
+            raise ValueError("File types should be identical.")
     else:
-        raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
+        raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

     if dataset_attr.load_from == "ms_hub":
         require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
-        from modelscope import MsDataset
-        from modelscope.utils.config_ds import MS_DATASETS_CACHE
+        from modelscope import MsDataset  # type: ignore
+        from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore

         cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
         dataset = MsDataset.load(
@@ -98,10 +97,27 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.ms_hub_token,
-            use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            use_streaming=data_args.streaming,
         )
         if isinstance(dataset, MsDataset):
             dataset = dataset.to_hf_dataset()
+
+    elif dataset_attr.load_from == "om_hub":
+        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        from openmind import OmDataset  # type: ignore
+        from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore
+
+        cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
+        dataset = OmDataset.load_dataset(
+            path=data_path,
+            name=data_name,
+            data_dir=data_dir,
+            data_files=data_files,
+            split=dataset_attr.split,
+            cache_dir=cache_dir,
+            token=model_args.om_hub_token,
+            streaming=data_args.streaming,
+        )
     else:
         dataset = load_dataset(
             path=data_path,
@@ -111,13 +127,10 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            streaming=data_args.streaming,
             trust_remote_code=True,
         )

-        if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
-            dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
-
     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
         indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
@@ -128,7 +141,7 @@ def _load_single_dataset(

         assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
         dataset = dataset.select(indexes)
-        logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
+        logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")

     if data_args.max_samples is not None:  # truncate dataset
         max_samples = min(data_args.max_samples, len(dataset))
@@ -224,9 +237,9 @@ def get_dataset(
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
+            logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
             dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
-            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
+            logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")

             dataset_module: Dict[str, "Dataset"] = {}
             if "train" in dataset_dict:
@@ -277,8 +290,8 @@ def get_dataset(
     if data_args.tokenized_path is not None:
         if training_args.should_save:
             dataset_dict.save_to_disk(data_args.tokenized_path)
-            logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
-            logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
+            logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
+            logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")

         sys.exit(0)
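The reordered local-file logic collects all files first, infers the dataset type from the first file's extension, then requires every file to agree. A standalone sketch with an abbreviated copy of the extension map:

```python
import os

FILEEXT2TYPE = {"json": "json", "jsonl": "json", "csv": "csv", "parquet": "parquet"}  # abbreviated copy

data_files = ["train.jsonl", "extra.json"]  # hypothetical local files
data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
assert data_path is not None, "Allowed file types: " + ",".join(FILEEXT2TYPE)
if any(data_path != FILEEXT2TYPE.get(os.path.splitext(f)[-1][1:], None) for f in data_files):
    raise ValueError("File types should be identical.")
print(data_path)  # "json" -- .jsonl and .json map to the same loader type, so this passes
```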
@@ -1,12 +1,15 @@
+import math
 from copy import deepcopy
 from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union

 import numpy as np
+import torch
 from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override

 from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
-from ..extras.packages import is_pillow_available, is_pyav_available
+from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than


 if is_pillow_available():
@@ -18,8 +21,15 @@ if is_pyav_available():
     import av


+if is_transformers_version_greater_than("4.45.0"):
+    from transformers.models.mllama.processing_mllama import (
+        convert_sparse_cross_attention_mask_to_dense,
+        get_cross_attention_token_mask,
+    )


 if TYPE_CHECKING:
     import torch
+    from av.stream import Stream
     from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.image_processing_utils import BaseImageProcessor

@@ -27,111 +37,10 @@ if TYPE_CHECKING:
         path: Optional[str]
         bytes: Optional[bytes]

-    ImageInput = Union[str, EncodedImage, ImageObject]
+    ImageInput = Union[str, bytes, EncodedImage, ImageObject]
     VideoInput = str


-def _regularize_images(
-    images: Sequence["ImageInput"],
-    processor: "ProcessorMixin",
-    max_resolution: Optional[int] = None,
-) -> List["ImageObject"]:
-    r"""
-    Regularizes images to avoid error. Including reading, resizing and converting.
-    """
-    if max_resolution is None:
-        max_resolution: int = getattr(processor, "image_resolution", 512)
-
-    results = []
-    for image in images:
-        if isinstance(image, str):
-            image = Image.open(image)
-        elif isinstance(image, dict):
-            if image["bytes"] is not None:
-                image = Image.open(BytesIO(image["bytes"]))
-            else:
-                image = Image.open(image["path"])
-
-        if not isinstance(image, ImageObject):
-            raise ValueError("Expect input is a list of Images, but got {}.".format(type(image)))
-
-        if max(image.width, image.height) > max_resolution:
-            factor = max_resolution / max(image.width, image.height)
-            image = image.resize((int(image.width * factor), int(image.height * factor)), resample=Image.NEAREST)
-
-        if image.mode != "RGB":
-            image = image.convert("RGB")
-
-        results.append(image)
-
-    return results
-
-
-def _regularize_videos(
-    videos: Sequence["VideoInput"],
-    processor: "ProcessorMixin",
-) -> List[List["ImageObject"]]:
-    r"""
-    Regularizes videos to avoid error. Including reading, resizing and converting.
-    """
-    video_resolution: int = getattr(processor, "video_resolution", 128)
-    video_fps: float = getattr(processor, "video_fps", 1.0)
-    video_maxlen: int = getattr(processor, "video_maxlen", 64)
-    video_factor: int = getattr(processor, "video_factor", 1)
-    results = []
-    for video in videos:
-        container = av.open(video, "r")
-        video_stream = next(stream for stream in container.streams if stream.type == "video")
-        total_frames = video_stream.frames
-        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
-        sample_frames = min(video_maxlen, sample_frames)  # reduce length <= maxlen
-        sample_frames = round(sample_frames / video_factor) * video_factor  # for qwen2_vl
-        sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
-        frames: List["ImageObject"] = []
-        container.seek(0)
-        for frame_idx, frame in enumerate(container.decode(video_stream)):
-            if frame_idx in sample_indices:
-                frames.append(frame.to_image())
-
-        frames = _regularize_images(frames, processor, video_resolution)
-        results.append(frames)
-
-    return results
-
-
-def _get_mm_inputs(
-    images: Sequence["ImageInput"],
-    videos: Sequence["VideoInput"],
-    processor: "ProcessorMixin",
-) -> Dict[str, "torch.Tensor"]:
-    r"""
-    Processes visual inputs.
-
-    Returns: (llava and paligemma)
-        pixel_values: tensor with shape (B, C, H, W)
-
-    Returns: (qwen2-vl)
-        pixel_values: tensor with shape (num_patches, patch_dim)
-        image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
-
-    It holds num_patches == torch.prod(image_grid_thw)
-    """
-    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-    input_dict = {"images": None}  # default key
-    if len(images) != 0:
-        images = _regularize_images(images, processor)
-        input_dict["images"] = images
-
-    if len(videos) != 0:
-        videos = _regularize_videos(videos, processor)
-        input_dict["videos"] = videos
-
-    if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
-        return image_processor(**input_dict, return_tensors="pt")
-    else:
-        return {}
-
-
 def _get_paligemma_token_type_ids(
     imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
 ) -> List[List[int]]:
@@ -159,12 +68,134 @@ class BasePlugin:
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
     ) -> None:
         r"""
         Validates if this model accepts the input modalities.
         """
         if len(images) != 0 and self.image_token is None:
             raise ValueError("This model does not support image input.")

         if len(videos) != 0 and self.video_token is None:
             raise ValueError("This model does not support video input.")

+    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+        r"""
+        Pre-processes a single image.
+        """
+        image_resolution: int = kwargs.get("image_resolution")
+        if (image.width * image.height) > image_resolution:
+            resize_factor = math.sqrt(image_resolution / (image.width * image.height))
+            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+            image = image.resize((width, height), resample=Image.NEAREST)
+
+        if image.mode != "RGB":
+            image = image.convert("RGB")
+
+        return image
+
+    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
+        r"""
+        Computes video sample frames according to fps.
+        """
+        video_fps: float = kwargs.get("video_fps")
+        video_maxlen: int = kwargs.get("video_maxlen")
+        total_frames = video_stream.frames
+        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
+        sample_frames = min(total_frames, video_maxlen, sample_frames)
+        return math.floor(sample_frames)
+
+    def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
+        r"""
+        Regularizes images to avoid error. Including reading and pre-processing.
+        """
+        results = []
+        for image in images:
+            if isinstance(image, str):
+                image = Image.open(image)
+            elif isinstance(image, bytes):
+                image = Image.open(BytesIO(image))
+            elif isinstance(image, dict):
+                if image["bytes"] is not None:
+                    image = Image.open(BytesIO(image["bytes"]))
+                else:
+                    image = Image.open(image["path"])
+
+            if not isinstance(image, ImageObject):
+                raise ValueError(f"Expect input is a list of Images, but got {type(image)}.")
+
+            results.append(self._preprocess_image(image, **kwargs))
+
+        return results
+
+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
+        r"""
+        Regularizes videos to avoid error. Including reading, resizing and converting.
+        """
+        results = []
+        for video in videos:
+            container = av.open(video, "r")
+            video_stream = next(stream for stream in container.streams if stream.type == "video")
+            total_frames = video_stream.frames
+            sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
+            sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+            frames: List["ImageObject"] = []
+            container.seek(0)
+            for frame_idx, frame in enumerate(container.decode(video_stream)):
+                if frame_idx in sample_indices:
+                    frames.append(frame.to_image())
+
+            frames = self._regularize_images(frames, **kwargs)
+            results.append(frames)
+
+        return results
+
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+    ) -> Dict[str, "torch.Tensor"]:
+        r"""
+        Processes visual inputs.
+
+        Returns: (llava and paligemma)
+            pixel_values: tensor with shape (B, C, H, W)
+
+        Returns: (qwen2-vl)
+            pixel_values: tensor with shape (num_patches, patch_dim)
+            image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
+
+        It holds num_patches == torch.prod(image_grid_thw)
+        """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
+        input_dict = {"images": None}  # default key
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_resolution=getattr(processor, "image_resolution", 512 * 512),
+            )
+            input_dict["images"] = images
+
+        if len(videos) != 0:
+            videos = self._regularize_videos(
+                videos,
+                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 64),
+            )
+            input_dict["videos"] = videos
+
+        mm_inputs = {}
+        if image_processor != video_processor:
+            if input_dict.get("images") is not None:
+                mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))
+            if input_dict.get("videos") is not None:
+                mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))
+        elif input_dict.get("images") is not None or input_dict.get("videos") is not None:  # same processor (qwen2-vl)
+            mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))
+
+        return mm_inputs
+
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
@@ -199,11 +230,19 @@ class BasePlugin:
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         r"""
         Builds batched multimodal inputs for VLMs.
+
+        Arguments:
+            images: a list of image inputs, shape (num_images,)
+            videos: a list of video inputs, shape (num_videos,)
+            imglens: number of images in each sample, shape (batch_size,)
+            vidlens: number of videos in each sample, shape (batch_size,)
+            batch_ids: input ids of samples, shape (batch_size, seq_len)
+            processor: a processor for pre-processing images and videos
         """
         self._validate_input(images, videos)
         return {}
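Resolution limits are now interpreted as a pixel-area budget (`width * height`), which is why the defaults read `512 * 512` rather than `512`: both sides shrink by the same square-root factor. A quick numeric check of `_preprocess_image`'s arithmetic:

```python
import math

width, height, image_resolution = 1000, 800, 512 * 512
if width * height > image_resolution:
    factor = math.sqrt(image_resolution / (width * height))
    width, height = int(width * factor), int(height * factor)

print(width, height, width * height)  # 572 457 261404, just under the 262144-pixel budget
```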
@@ -226,12 +265,12 @@ class LlavaPlugin(BasePlugin):
     content = message["content"]
     while IMAGE_PLACEHOLDER in content:
         num_image_tokens += 1
-        content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+        content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

-    message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
+    message["content"] = content.replace("{{image}}", self.image_token)

 if len(images) != num_image_tokens:
-    raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+    raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

 return messages

@@ -242,11 +281,129 @@ class LlavaPlugin(BasePlugin):
     videos: Sequence["VideoInput"],
     imglens: Sequence[int],
     vidlens: Sequence[int],
-    seqlens: Sequence[int],
+    batch_ids: Sequence[List[int]],
     processor: Optional["ProcessorMixin"],
 ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
     self._validate_input(images, videos)
-    return _get_mm_inputs(images, videos, processor)
+    return self._get_mm_inputs(images, videos, processor)
+
+
+class LlavaNextPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if "image_sizes" in mm_inputs:
+            image_sizes = iter(mm_inputs["image_sizes"])
+
+        if "pixel_values" in mm_inputs:
+            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                image_size = next(image_sizes)
+                orig_height, orig_width = image_size
+                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                if getattr(processor, "vision_feature_select_strategy") == "default":
+                    image_seqlen -= 1
+
+                num_image_tokens += 1
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+
+            message["content"] = content.replace("{{image}}", self.image_token)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        return self._get_mm_inputs(images, videos, processor)
+
+
+class LlavaNextVideoPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens, num_video_tokens = 0, 0
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if "pixel_values" in mm_inputs:
+            image_sizes = iter(mm_inputs["image_sizes"])
+            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+            for message in messages:
+                content = message["content"]
+                while IMAGE_PLACEHOLDER in content:
+                    image_size = next(image_sizes)
+                    orig_height, orig_width = image_size
+                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                    if getattr(processor, "vision_feature_select_strategy") == "default":
+                        image_seqlen -= 1
+
+                    num_image_tokens += 1
+                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+
+                message["content"] = content.replace("{{image}}", self.image_token)
+
+        if "pixel_values_videos" in mm_inputs:
+            pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+            height, width = get_image_size(pixel_values_video[0])
+            num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
+            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
+            video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
+            for message in messages:
+                content = message["content"]
+                while VIDEO_PLACEHOLDER in content:
+                    num_video_tokens += 1
+                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
+
+                message["content"] = content.replace("{{video}}", self.video_token)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        return self._get_mm_inputs(images, videos, processor)


 class PaliGemmaPlugin(BasePlugin):
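The Llava-family plugins above all share one move: expand each image placeholder into `image_seqlen` copies of a `{{image}}` scaffold first, then swap the scaffold for the model's real image token. A toy run with made-up tokens and a made-up sequence length:

```python
IMAGE_PLACEHOLDER, image_token, image_seqlen = "<image>", "<img>", 4  # illustrative values

content = "<image>Describe the photo."
while IMAGE_PLACEHOLDER in content:
    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

print(content.replace("{{image}}", image_token))  # <img><img><img><img>Describe the photo.
```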
@@ -270,7 +427,7 @@ class PaliGemmaPlugin(BasePlugin):
     message["content"] = content.replace("{{image}}", "")

 if len(images) != num_image_tokens:
-    raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+    raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

 return messages

@@ -301,16 +458,102 @@ class PaliGemmaPlugin(BasePlugin):
     videos: Sequence["VideoInput"],
     imglens: Sequence[int],
     vidlens: Sequence[int],
-    seqlens: Sequence[int],
+    batch_ids: Sequence[List[int]],
     processor: Optional["ProcessorMixin"],
 ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
     self._validate_input(images, videos)
-    mm_inputs = _get_mm_inputs(images, videos, processor)
+    seqlens = [len(input_ids) for input_ids in batch_ids]
+    mm_inputs = self._get_mm_inputs(images, videos, processor)
     mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
     return mm_inputs


+class PixtralPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        patch_size = getattr(processor, "patch_size")
+        image_token = getattr(processor, "image_token")
+        image_break_token = getattr(processor, "image_break_token")
+        image_end_token = getattr(processor, "image_end_token")
+
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        image_input_sizes = mm_inputs.get("image_sizes", None)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if image_input_sizes is None:
+                    raise ValueError("Cannot get image input sizes.")
+
+                image_size = image_input_sizes[0][num_image_tokens]
+                height, width = image_size
+                num_height_tokens = height // patch_size
+                num_width_tokens = width // patch_size
+                replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
+                replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
+                replace_tokens[-1] = image_end_token
+                replace_str = "".join(replace_tokens)
+                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
+                num_image_tokens += 1
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if mm_inputs.get("pixel_values"):
+            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
+
+        mm_inputs.pop("image_sizes", None)
+        return mm_inputs
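The Pixtral expansion above builds a two-dimensional token map: each row of patches is a run of image tokens ended by an image-break token, and the final slot becomes the image-end token. A toy rendering with made-up tokens and a hypothetical processed image size:

```python
image_token, image_break_token, image_end_token = "[IMG]", "[BRK]", "[END]"
height, width, patch_size = 32, 48, 16  # hypothetical processor output: 2 x 3 patches

rows, cols = height // patch_size, width // patch_size
replace_tokens = [[image_token] * cols + [image_break_token] for _ in range(rows)]
replace_tokens = [tok for row in replace_tokens for tok in row]  # flatten the rows
replace_tokens[-1] = image_end_token
print("".join(replace_tokens))  # [IMG][IMG][IMG][BRK][IMG][IMG][IMG][END]
```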
 class Qwen2vlPlugin(BasePlugin):
+    @override
+    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+        image = super()._preprocess_image(image, **kwargs)
+        if min(image.width, image.height) < 28:
+            width, height = max(image.width, 28), max(image.height, 28)
+            image = image.resize((width, height), resample=Image.NEAREST)
+
+        if image.width / image.height > 200:
+            width, height = image.height * 180, image.height
+            image = image.resize((width, height), resample=Image.NEAREST)
+
+        if image.height / image.width > 200:
+            width, height = image.width, image.width * 180
+            image = image.resize((width, height), resample=Image.NEAREST)
+
+        return image
+
+    @override
+    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
+        sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
+        sample_frames = sample_frames // 2 * 2
+        return sample_frames
+
     @override
     def process_messages(
         self,
@@ -322,7 +565,7 @@ class Qwen2vlPlugin(BasePlugin):
     self._validate_input(images, videos)
     image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
    merge_length: int = getattr(image_processor, "merge_size") ** 2
-    mm_inputs = _get_mm_inputs(images, videos, processor)
+    mm_inputs = self._get_mm_inputs(images, videos, processor)
     image_grid_thw = mm_inputs.get("image_grid_thw", [])
     video_grid_thw = mm_inputs.get("video_grid_thw", [])

@@ -332,7 +575,7 @@ class Qwen2vlPlugin(BasePlugin):
     content = message["content"]
     while IMAGE_PLACEHOLDER in content:
         if num_image_tokens >= len(image_grid_thw):
-            raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))
+            raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

         content = content.replace(
             IMAGE_PLACEHOLDER,
@@ -345,7 +588,7 @@ class Qwen2vlPlugin(BasePlugin):

     while VIDEO_PLACEHOLDER in content:
         if num_video_tokens >= len(video_grid_thw):
-            raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
+            raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

         content = content.replace(
             VIDEO_PLACEHOLDER,
@@ -359,10 +602,10 @@ class Qwen2vlPlugin(BasePlugin):
     message["content"] = content

 if len(images) != num_image_tokens:
-    raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+    raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

 if len(videos) != num_video_tokens:
-    raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
+    raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

 return messages

@@ -373,18 +616,162 @@ class Qwen2vlPlugin(BasePlugin):
     videos: Sequence["VideoInput"],
     imglens: Sequence[int],
     vidlens: Sequence[int],
-    seqlens: Sequence[int],
+    batch_ids: Sequence[List[int]],
     processor: Optional["ProcessorMixin"],
 ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
     self._validate_input(images, videos)
-    return _get_mm_inputs(images, videos, processor)
+    return self._get_mm_inputs(images, videos, processor)
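For Qwen2-VL, the number of placeholder tokens per image is the patch-grid volume divided by `merge_size ** 2` (spatially merged patches). A toy computation with a hypothetical grid:

```python
merge_size = 2
merge_length = merge_size ** 2
image_grid_thw = (1, 40, 60)  # hypothetical grid dims in patches, as returned by the processor

t, h, w = image_grid_thw
num_image_tokens = (t * h * w) // merge_length
print(num_image_tokens)  # 600 image-pad tokens inserted for this image
```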
 class VideoLlavaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         num_frames = 0
         has_images = "pixel_values_images" in mm_inputs
         has_videos = "pixel_values_videos" in mm_inputs
         if has_images or has_videos:
             if has_images:
                 height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
                 num_frames = 1

             if has_videos:
                 pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                 height, width = get_image_size(pixel_values_video[0])
                 num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim

             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
             video_seqlen = image_seqlen * num_frames
             if getattr(processor, "vision_feature_select_strategy") == "default":
                 image_seqlen -= 1

             for message in messages:
                 content = message["content"]
                 while IMAGE_PLACEHOLDER in content:
                     num_image_tokens += 1
                     content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

                 while VIDEO_PLACEHOLDER in content:
                     num_video_tokens += 1
                     content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)

                 content = content.replace("{{image}}", self.image_token)
                 message["content"] = content.replace("{{video}}", self.video_token)

         if len(images) != num_image_tokens:
             raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         if len(videos) != num_video_tokens:
             raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

         return messages

     @override
     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
         return self._get_mm_inputs(images, videos, processor)


 class MllamaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
             num_image_tokens += content.count(IMAGE_PLACEHOLDER)
             message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)

         if len(images) != num_image_tokens:
             raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         return messages

     @override
     def _get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         processor: "ProcessorMixin",
     ) -> Dict[str, "torch.Tensor"]:
         r"""
         Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].

         Returns:
             pixel_values: tensor with shape
                 (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
                 For example, (2, 1, 4, 3, 560, 560).
             aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
             aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
             num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
         """
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
         return image_processor([[image] for image in images], return_tensors="pt")

     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
         if len(images) != len(batch_ids):
             raise ValueError("Mllama only supports one image per sample.")

         mm_inputs = self._get_mm_inputs(images, videos, processor)
         num_tiles = mm_inputs.pop("num_tiles")
         image_token_id = getattr(processor, "image_token_id")
         max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
         cross_attention_token_mask = [
             get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
         ]
         mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
             cross_attention_token_mask,
             num_tiles=num_tiles,
             max_num_tiles=max_image_tiles,
             length=max(len(input_ids) for input_ids in batch_ids),
         )
         return mm_inputs


 PLUGINS = {
     "base": BasePlugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
     "paligemma": PaliGemmaPlugin,
     "pixtral": PixtralPlugin,
     "qwen2_vl": Qwen2vlPlugin,
     "video_llava": VideoLlavaPlugin,
     "mllama": MllamaPlugin,
 }


@@ -395,6 +782,6 @@ def get_mm_plugin(
 ) -> "BasePlugin":
     plugin_class = PLUGINS.get(name, None)
     if plugin_class is None:
-        raise ValueError("Multimodal plugin `{}` not found.".format(name))
+        raise ValueError(f"Multimodal plugin `{name}` not found.")

     return plugin_class(image_token, video_token)
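A hedged usage sketch of the plugin registry above: resolve a plugin by name and hand it the model's special tokens (the token strings below are assumptions for Qwen2-VL, not taken from this diff):

```python
from llamafactory.data.mm_plugin import get_mm_plugin

plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>")
print(type(plugin).__name__)  # Qwen2vlPlugin

# get_mm_plugin(name="unknown") would raise: Multimodal plugin `unknown` not found.
```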
@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Literal, Optional, Sequence
 from transformers.utils import cached_file

 from ..extras.constants import DATA_CONFIG
-from ..extras.misc import use_modelscope
+from ..extras.misc import use_modelscope, use_openmind


 @dataclass
@@ -30,7 +30,7 @@ class DatasetAttr:
     """

     # basic configs
-    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
+    load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
     dataset_name: str
     formatting: Literal["alpaca", "sharegpt"] = "alpaca"
     ranking: bool = False
@@ -87,31 +87,39 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
     config_path = os.path.join(dataset_dir, DATA_CONFIG)

     try:
-        with open(config_path, "r") as f:
+        with open(config_path) as f:
             dataset_info = json.load(f)
     except Exception as err:
         if len(dataset_names) != 0:
-            raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
+            raise ValueError(f"Cannot open {config_path} due to {str(err)}.")

         dataset_info = None

     dataset_list: List["DatasetAttr"] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
-            load_from = "ms_hub" if use_modelscope() else "hf_hub"
+            if use_modelscope():
+                load_from = "ms_hub"
+            elif use_openmind():
+                load_from = "om_hub"
+            else:
+                load_from = "hf_hub"
+
             dataset_attr = DatasetAttr(load_from, dataset_name=name)
             dataset_list.append(dataset_attr)
             continue

         if name not in dataset_info:
-            raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
+            raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")

         has_hf_url = "hf_hub_url" in dataset_info[name]
         has_ms_url = "ms_hub_url" in dataset_info[name]
+        has_om_url = "om_hub_url" in dataset_info[name]

-        if has_hf_url or has_ms_url:
-            if (use_modelscope() and has_ms_url) or (not has_hf_url):
+        if has_hf_url or has_ms_url or has_om_url:
+            if has_ms_url and (use_modelscope() or not has_hf_url):
                 dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
+            elif has_om_url and (use_openmind() or not has_hf_url):
+                dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
             else:
                 dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
         elif "script_url" in dataset_info[name]:
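A sketch of the hub-selection precedence encoded above, with the environment checks stubbed out (the real `use_modelscope()` / `use_openmind()` read opt-in environment variables):

```python
def pick_hub(has_hf, has_ms, has_om, use_ms=False, use_om=False):  # simplified stand-in
    if has_ms and (use_ms or not has_hf):
        return "ms_hub"
    if has_om and (use_om or not has_hf):
        return "om_hub"
    return "hf_hub"

print(pick_hub(has_hf=True, has_ms=True, has_om=False))               # hf_hub by default
print(pick_hub(has_hf=True, has_ms=True, has_om=False, use_ms=True))  # ms_hub when opted in
print(pick_hub(has_hf=False, has_ms=False, has_om=True))              # om_hub as the only source
```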
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import infer_seqlen


@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_feedback_example(
@@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
@@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
     desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
     undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
     if desirable_num == 0 or undesirable_num == 0:
-        logger.warning("Your dataset only has one preference type.")
+        logger.warning_rank0("Your dataset only has one preference type.")

     return model_inputs
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import infer_seqlen


@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_pairwise_example(
@@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
@@ -110,8 +112,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer"):
     print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
     print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
     print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
-    print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
+    print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
     print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
     print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
     print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
-    print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
+    print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import greedy_knapsack, infer_seqlen


@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_supervised_example(
@@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_supervised_example(
@@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
     length2indexes = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_supervised_example(
@@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
         )
         length = len(input_ids)
         if length > data_args.cutoff_len:
-            logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
+            logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
         else:
             lengths.append(length)
             length2indexes[length].append(valid_num)
@@ -212,4 +216,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
     print("input_ids:\n{}".format(example["input_ids"]))
     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
     print("label_ids:\n{}".format(example["labels"]))
-    print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
+    print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
@@ -15,7 +15,7 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

-from ...extras.logging import get_logger
+from ...extras import logging
 from ..data_utils import Role
 from .processor_utils import infer_seqlen

@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_unsupervised_example(
@@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_unsupervised_example(
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 from transformers.utils.versions import require_version
 from typing_extensions import override

-from ..extras.logging import get_logger
+from ..extras import logging
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
     from .mm_plugin import BasePlugin


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 @dataclass
@@ -49,6 +49,7 @@ class Template:
     stop_words: List[str]
     efficient_eos: bool
     replace_eos: bool
+    replace_jinja_template: bool
     mm_plugin: "BasePlugin"

     def encode_oneturn(
@@ -146,7 +147,7 @@ class Template:
         elif "eos_token" in elem and tokenizer.eos_token_id is not None:
             token_ids += [tokenizer.eos_token_id]
         else:
-            raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
+            raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")

         return token_ids

@@ -214,6 +215,7 @@ def _register_template(
     stop_words: Sequence[str] = [],
     efficient_eos: bool = False,
     replace_eos: bool = False,
+    replace_jinja_template: bool = True,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
 ) -> None:
     r"""
@@ -263,6 +265,7 @@ def _register_template(
         stop_words=stop_words,
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
+        replace_jinja_template=replace_jinja_template,
         mm_plugin=mm_plugin,
     )

@@ -272,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
     num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})

     if is_added:
-        logger.info("Add eos token: {}".format(tokenizer.eos_token))
+        logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
     else:
-        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+        logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")

     if num_added_tokens > 0:
-        logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+        logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")


 def _jinja_escape(content: str) -> str:
@@ -353,23 +356,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     r"""
     Gets chat template and fixes the tokenizer.
     """
-    if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
-        require_version(
-            "transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
-        )
-
     if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
         template = TEMPLATES.get(data_args.template, None)
         if template is None:
-            raise ValueError("Template {} does not exist.".format(data_args.template))
+            raise ValueError(f"Template {data_args.template} does not exist.")

+    if template.mm_plugin.__class__.__name__ != "BasePlugin":
+        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")

     if data_args.tool_format is not None:
-        logger.info("Using tool format: {}.".format(data_args.tool_format))
+        logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
         eos_slots = [] if template.efficient_eos else [{"eos_token"}]
         template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
@@ -387,20 +388,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:

     if tokenizer.pad_token_id is None:
         tokenizer.pad_token = tokenizer.eos_token
-        logger.info("Add pad token: {}".format(tokenizer.pad_token))
+        logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")

     if stop_words:
         num_added_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
         )
-        logger.info("Add {} to stop words.".format(",".join(stop_words)))
+        logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
         if num_added_tokens > 0:
-            logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

-    try:
-        tokenizer.chat_template = _get_jinja_template(template, tokenizer)
-    except ValueError:
-        logger.info("Cannot add this chat template to tokenizer.")
+    if tokenizer.chat_template is None or template.replace_jinja_template:
+        try:
+            tokenizer.chat_template = _get_jinja_template(template, tokenizer)
+        except ValueError as e:
+            logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")

     return template

@@ -639,6 +641,14 @@ _register_template(
 )


+_register_template(
+    name="exaone",
+    format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
+    format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+)
+
+
 _register_template(
     name="falcon",
     format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -663,6 +673,7 @@ _register_template(
     format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     efficient_eos=True,
+    replace_jinja_template=False,
 )


@@ -680,6 +691,14 @@ _register_template(
 )


+_register_template(
+    name="index",
+    format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
+    format_system=StringFormatter(slots=["<unk>{{content}}"]),
+    efficient_eos=True,
+)
+
+
 _register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
@@ -739,6 +758,34 @@ _register_template(
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=["<|eot_id|>"],
     replace_eos=True,
+    replace_jinja_template=False,
+)
+
+
+_register_template(
+    name="mllama",
+    format_user=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+    format_observation=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|eot_id|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+    mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
 )


@@ -753,6 +800,107 @@ _register_template(
 )


+_register_template(
+    name="llava_next",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_llama3",
+    format_user=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+    format_observation=StringFormatter(
+        slots=[
+            (
+                "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>assistant<|end_header_id|>\n\n"
+            )
+        ]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|eot_id|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_mistral",
+    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_qwen",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    default_system="You are a helpful assistant.",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_yi",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_video",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
+)
+
+
+_register_template(
+    name="llava_next_video_mistral",
+    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
+)
+
+
+_register_template(
+    name="llava_next_video_yi",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
+)
+
+
 _register_template(
     name="mistral",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
@@ -790,6 +938,19 @@ _register_template(
 )


+_register_template(
+    name="opencoder",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    default_system="You are OpenCoder, created by OpenCoder Team.",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    replace_jinja_template=False,
+)
+
+
 _register_template(
     name="orion",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
@@ -821,6 +982,25 @@ _register_template(
 )


+_register_template(
+    name="phi_small",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
+    stop_words=["<|end|>"],
+    replace_eos=True,
+)
+
+
+_register_template(
+    name="pixtral",
+    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
+)
+
+
 _register_template(
     name="qwen",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -830,6 +1010,7 @@ _register_template(
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
     replace_eos=True,
+    replace_jinja_template=False,
 )


@@ -842,6 +1023,7 @@ _register_template(
     default_system="You are a helpful assistant.",
     stop_words=["<|im_end|>"],
     replace_eos=True,
+    replace_jinja_template=False,
     mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
 )

@@ -897,6 +1079,17 @@ _register_template(
 )


+_register_template(
+    name="video_llava",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="video_llava", image_token="<image>", video_token="<video>"),
+)
+
+
 _register_template(
     name="xuanyuan",
     format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
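Editor's note: a small sketch of the new `replace_jinja_template` gate introduced above (the helper name is illustrative, not the project's): the auto-generated Jinja template is written to the tokenizer only when the tokenizer ships none, or when the registered template explicitly opts in.

def should_write_jinja(existing_chat_template, replace_jinja_template: bool) -> bool:
    # mirrors: if tokenizer.chat_template is None or template.replace_jinja_template
    return existing_chat_template is None or replace_jinja_template

assert should_write_jinja(None, False)            # tokenizer had no template
assert not should_write_jinja("{{...}}", False)   # keep the shipped template
assert should_write_jinja("{{...}}", True)        # template opts in to replace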
@@ -177,6 +177,6 @@ TOOLS = {
 def get_tool_utils(name: str) -> "ToolUtils":
     tool_utils = TOOLS.get(name, None)
     if tool_utils is None:
-        raise ValueError("Tool utils `{}` not found.".format(name))
+        raise ValueError(f"Tool utils `{name}` not found.")

     return tool_utils
@@ -87,7 +87,7 @@ class Evaluator:
             token=self.model_args.hf_hub_token,
         )

-        with open(mapping, "r", encoding="utf-8") as f:
+        with open(mapping, encoding="utf-8") as f:
             categorys: Dict[str, Dict[str, str]] = json.load(f)

         category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
@@ -139,7 +139,7 @@ class Evaluator:
     def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
         score_info = "\n".join(
             [
-                "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
+                f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
                 for category_name, category_correct in category_corrects.items()
                 if len(category_correct)
             ]
@@ -61,7 +61,7 @@ def _register_eval_template(name: str, system: str, choice: str, answer: str) ->

 def get_eval_template(name: str) -> "EvalTemplate":
     eval_template = eval_templates.get(name, None)
-    assert eval_template is not None, "Template {} does not exist.".format(name)
+    assert eval_template is not None, f"Template {name} does not exist."
     return eval_template
File diff suppressed because it is too large
@@ -26,7 +26,7 @@ import trl
 from transformers.utils import is_torch_cuda_available, is_torch_npu_available


-VERSION = "0.9.0"
+VERSION = "0.9.1"


 def print_env() -> None:
@@ -72,4 +72,4 @@ def print_env() -> None:
     except Exception:
         pass

-    print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
+    print("\n" + "\n".join([f"- {key}: {value}" for key, value in info.items()]) + "\n")
@@ -20,6 +20,7 @@ import os
 import sys
 import threading
 from concurrent.futures import ThreadPoolExecutor
+from functools import lru_cache
 from typing import Optional

 from .constants import RUNNING_LOG
@@ -37,12 +38,11 @@ class LoggerHandler(logging.Handler):

     def __init__(self, output_dir: str) -> None:
         super().__init__()
-        formatter = logging.Formatter(
-            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
+        self._formatter = logging.Formatter(
+            fmt="[%(levelname)s|%(asctime)s] %(filename)s:%(lineno)s >> %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
         )
-        self.setLevel(logging.INFO)
-        self.setFormatter(formatter)

         os.makedirs(output_dir, exist_ok=True)
         self.running_log = os.path.join(output_dir, RUNNING_LOG)
         if os.path.exists(self.running_log):
@@ -58,7 +58,7 @@ class LoggerHandler(logging.Handler):
         if record.name == "httpx":
             return

-        log_entry = self.format(record)
+        log_entry = self._formatter.format(record)
         self.thread_pool.submit(self._write_log, log_entry)

     def close(self) -> None:
@@ -66,6 +66,21 @@ class LoggerHandler(logging.Handler):
         return super().close()


+class _Logger(logging.Logger):
+    r"""
+    A logger that supports info_rank0 and warning_once.
+    """
+
+    def info_rank0(self, *args, **kwargs) -> None:
+        self.info(*args, **kwargs)
+
+    def warning_rank0(self, *args, **kwargs) -> None:
+        self.warning(*args, **kwargs)
+
+    def warning_once(self, *args, **kwargs) -> None:
+        self.warning(*args, **kwargs)
+
+
 def _get_default_logging_level() -> "logging._Level":
     r"""
     Returns the default logging level.
@@ -75,7 +90,7 @@ def _get_default_logging_level() -> "logging._Level":
     if env_level_str.upper() in logging._nameToLevel:
         return logging._nameToLevel[env_level_str.upper()]
     else:
-        raise ValueError("Unknown logging level: {}.".format(env_level_str))
+        raise ValueError(f"Unknown logging level: {env_level_str}.")

     return _default_log_level

@@ -84,7 +99,7 @@ def _get_library_name() -> str:
     return __name__.split(".")[0]


-def _get_library_root_logger() -> "logging.Logger":
+def _get_library_root_logger() -> "_Logger":
     return logging.getLogger(_get_library_name())


@@ -95,12 +110,12 @@ def _configure_library_root_logger() -> None:
     global _default_handler

     with _thread_lock:
-        if _default_handler:
+        if _default_handler:  # already configured
             return

         formatter = logging.Formatter(
-            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-            datefmt="%m/%d/%Y %H:%M:%S",
+            fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
         )
         _default_handler = logging.StreamHandler(sys.stdout)
         _default_handler.setFormatter(formatter)
@@ -110,7 +125,7 @@ def _configure_library_root_logger() -> None:
     library_root_logger.propagate = False


-def get_logger(name: Optional[str] = None) -> "logging.Logger":
+def get_logger(name: Optional[str] = None) -> "_Logger":
     r"""
     Returns a logger with the specified name. It it not supposed to be accessed externally.
     """
@@ -119,3 +134,40 @@ def get_logger(name: Optional[str] = None) -> "logging.Logger":

     _configure_library_root_logger()
     return logging.getLogger(name)
+
+
+def add_handler(handler: "logging.Handler") -> None:
+    r"""
+    Adds a handler to the root logger.
+    """
+    _configure_library_root_logger()
+    _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+    r"""
+    Removes a handler to the root logger.
+    """
+    _configure_library_root_logger()
+    _get_library_root_logger().removeHandler(handler)
+
+
+def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.info(*args, **kwargs)
+
+
+def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.warning(*args, **kwargs)
+
+
+@lru_cache(None)
+def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.warning(*args, **kwargs)
+
+
+logging.Logger.info_rank0 = info_rank0
+logging.Logger.warning_rank0 = warning_rank0
+logging.Logger.warning_once = warning_once
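Editor's note: a hedged usage sketch of the patched logger above. Once the module's monkey-patching block has run (e.g. by importing it), every `logging.Logger` carries the rank-aware helpers: under a multi-process launcher only the `LOCAL_RANK=0` process emits the record, and `warning_once` deduplicates repeated calls via `lru_cache` on its arguments.

import logging as std_logging

std_logging.basicConfig(level=std_logging.INFO)
logger = std_logging.getLogger("demo")

# assumes the patch above has been applied to logging.Logger
logger.info_rank0("printed on the rank-0 process only")   # type: ignore[attr-defined]
logger.warning_once("printed at most once per message")   # type: ignore[attr-defined]
logger.warning_once("printed at most once per message")   # deduplicated by lru_cache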
@@ -20,6 +20,7 @@ import os
 from typing import TYPE_CHECKING, Tuple, Union

 import torch
+import torch.distributed as dist
 import transformers.dynamic_module_utils
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 from transformers.dynamic_module_utils import get_relative_imports
@@ -32,7 +33,7 @@ from transformers.utils import (
 )
 from transformers.utils.versions import require_version

-from .logging import get_logger
+from . import logging


 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
@@ -48,7 +49,7 @@ if TYPE_CHECKING:
     from ..hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class AverageMeter:
@@ -76,12 +77,12 @@ def check_dependencies() -> None:
     r"""
     Checks the version of the required packages.
     """
-    if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
-        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
+    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
+        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
-        require_version("transformers>=4.41.2,<=4.45.0", "To fix: pip install transformers>=4.41.2,<=4.45.0")
-        require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
-        require_version("accelerate>=0.30.1,<=0.33.0", "To fix: pip install accelerate>=0.30.1,<=0.33.0")
+        require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+        require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
+        require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
         require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
         require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")

@@ -231,18 +232,43 @@ def torch_gc() -> None:
     torch.cuda.empty_cache()


-def try_download_model_from_ms(model_args: "ModelArguments") -> str:
-    if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
+def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
+    if (not use_modelscope() and not use_openmind()) or os.path.exists(model_args.model_name_or_path):
         return model_args.model_name_or_path

-    try:
-        from modelscope import snapshot_download
+    if use_modelscope():
+        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        from modelscope import snapshot_download  # type: ignore

         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
-        return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
-    except ImportError:
-        raise ImportError("Please install modelscope via `pip install modelscope -U`")
+        return snapshot_download(
+            model_args.model_name_or_path,
+            revision=revision,
+            cache_dir=model_args.cache_dir,
+        )
+
+    if use_openmind():
+        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        from openmind.utils.hub import snapshot_download  # type: ignore
+
+        return snapshot_download(
+            model_args.model_name_or_path,
+            revision=model_args.model_revision,
+            cache_dir=model_args.cache_dir,
+        )


 def use_modelscope() -> bool:
     return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]


+def use_openmind() -> bool:
+    return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
+
+
+def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
+    r"""
+    calculate effective tokens.
+    """
+    result = effective_token_num * epoch / train_runtime
+    return result / dist.get_world_size() if dist.is_initialized() else result
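Editor's note: a quick arithmetic check of `cal_effective_tokens` above (single process, so `dist` is not initialized and no world-size division applies): 1,000,000 effective tokens over 2 epochs in 400 seconds yields 5,000 effective tokens per second.

effective_token_num, epoch, train_runtime = 1_000_000, 2, 400.0
result = effective_token_num * epoch / train_runtime
assert result == 5000.0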
@@ -75,8 +75,13 @@ def is_starlette_available():


 @lru_cache
-def is_transformers_version_greater_than_4_43():
-    return _get_package_version("transformers") >= version.parse("4.43.0")
+def is_transformers_version_greater_than(content: str):
+    return _get_package_version("transformers") >= version.parse(content)
+
+
+@lru_cache
+def is_transformers_version_equal_to_4_46():
+    return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")


 def is_uvicorn_available():
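Editor's note: a self-contained sketch of the generalized version predicate above. `_get_package_version` is assumed to resolve the installed version roughly like this; the real helper lives elsewhere in the module and may differ in detail.

import importlib.metadata
from functools import lru_cache

from packaging import version

@lru_cache
def _get_package_version(name: str) -> "version.Version":
    # assumption: mirrors the module's helper; absent packages map to 0.0.0
    try:
        return version.parse(importlib.metadata.version(name))
    except importlib.metadata.PackageNotFoundError:
        return version.parse("0.0.0")

@lru_cache
def is_transformers_version_greater_than(content: str) -> bool:
    return _get_package_version("transformers") >= version.parse(content)

# the old pinned helper becomes a parameterized call:
print(is_transformers_version_greater_than("4.43.0"))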
@@ -19,7 +19,7 @@ from typing import Any, Dict, List

 from transformers.trainer import TRAINER_STATE_NAME

-from .logging import get_logger
+from . import logging
 from .packages import is_matplotlib_available


@@ -28,7 +28,7 @@ if is_matplotlib_available():
     import matplotlib.pyplot as plt


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def smooth(scalars: List[float]) -> List[float]:
@@ -75,7 +75,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
     Plots loss curves and saves the image.
     """
     plt.switch_backend("agg")
-    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
+    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), encoding="utf-8") as f:
         data = json.load(f)

     for key in keys:
@@ -86,13 +86,13 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
             metrics.append(data["log_history"][i][key])

         if len(metrics) == 0:
-            logger.warning(f"No metric {key} to plot.")
+            logger.warning_rank0(f"No metric {key} to plot.")
             continue

         plt.figure()
         plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
         plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
-        plt.title("training {} of {}".format(key, save_dictionary))
+        plt.title(f"training {key} of {save_dictionary}")
         plt.xlabel("step")
         plt.ylabel(key)
         plt.legend()
@@ -41,8 +41,12 @@ class DataArguments:
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
+    image_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
+    )
     cutoff_len: int = field(
-        default=1024,
+        default=2048,
         metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
     )
     train_on_prompt: bool = field(
@@ -111,7 +115,13 @@ class DataArguments:
     )
     tokenized_path: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to save or load the tokenized datasets."},
+        metadata={
+            "help": (
+                "Path to save or load the tokenized datasets. "
+                "If tokenized_path not exists, it will save the tokenized datasets. "
+                "If tokenized_path exists, it will load the tokenized datasets."
+            )
+        },
     )

     def __post_init__(self):
@@ -123,6 +133,9 @@ class DataArguments:
         self.dataset = split_arg(self.dataset)
         self.eval_dataset = split_arg(self.eval_dataset)

+        if self.image_dir is None:
+            self.image_dir = self.dataset_dir
+
         if self.dataset is None and self.val_size > 1e-6:
             raise ValueError("Cannot specify `val_size` if `dataset` is None.")
@@ -346,6 +346,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
     )
+    include_effective_tokens_per_second: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to compute effective tokens per second."},
+    )

     def __post_init__(self):
         def split_arg(arg):
@@ -15,10 +15,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import asdict, dataclass, field, fields
+import json
+from dataclasses import dataclass, field, fields
 from typing import Any, Dict, Literal, Optional, Union

 import torch
+from transformers.training_args import _convert_str_dict
 from typing_extensions import Self


@@ -57,12 +59,12 @@ class ProcessorArguments:
     """

     image_resolution: int = field(
-        default=512,
-        metadata={"help": "Keeps the height or width of image below this resolution."},
+        default=512 * 512,
+        metadata={"help": "Keeps the number of pixels of image below this resolution."},
     )
     video_resolution: int = field(
-        default=128,
-        metadata={"help": "Keeps the height or width of video below this resolution."},
+        default=128 * 128,
+        metadata={"help": "Keeps the number of pixels of video below this resolution."},
     )
     video_fps: float = field(
         default=2.0,
@@ -125,7 +127,7 @@ class VllmArguments:
     """

     vllm_maxlen: int = field(
-        default=2048,
+        default=4096,
         metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
     )
     vllm_gpu_util: float = field(
@@ -140,6 +142,10 @@ class VllmArguments:
         default=32,
         metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
     )
+    vllm_config: Optional[Union[dict, str]] = field(
+        default=None,
+        metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
+    )


 @dataclass
@@ -267,6 +273,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         default=None,
         metadata={"help": "Auth token to log in with ModelScope Hub."},
     )
+    om_hub_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token to log in with Modelers Hub."},
+    )
     print_param_status: bool = field(
         default=False,
         metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
@@ -308,20 +318,21 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
             raise ValueError("Quantization dataset is necessary for exporting.")

-    def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
+        if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
+            self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))

     @classmethod
-    def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
-        arg_dict = old_arg.to_dict()
-        arg_dict.update(**kwargs)
-        for attr in fields(cls):
-            if not attr.init:
-                arg_dict.pop(attr.name)
+    def copyfrom(cls, source: "Self", **kwargs) -> "Self":
+        init_args, lazy_args = {}, {}
+        for attr in fields(source):
+            if attr.init:
+                init_args[attr.name] = getattr(source, attr.name)
+            else:
+                lazy_args[attr.name] = getattr(source, attr.name)

-        new_arg = cls(**arg_dict)
-        new_arg.compute_dtype = old_arg.compute_dtype
-        new_arg.device_map = old_arg.device_map
-        new_arg.model_max_length = old_arg.model_max_length
-        new_arg.block_diag_attn = old_arg.block_diag_attn
-        return new_arg
+        init_args.update(kwargs)
+        result = cls(**init_args)
+        for name, value in lazy_args.items():
+            setattr(result, name, value)
+
+        return result
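Editor's note: the `copyfrom` rewrite above exists because `dataclasses` rejects `init=False` fields in the constructor, so they must be restored after instantiation instead of hard-coding each one. A minimal reproduction of the pattern (the `Args` class is illustrative, not the project's):

from dataclasses import dataclass, field, fields

@dataclass
class Args:
    x: int = 1
    y: int = field(default=2, init=False)  # derived field, rejected by __init__

    @classmethod
    def copyfrom(cls, source: "Args", **kwargs) -> "Args":
        init_args, lazy_args = {}, {}
        for attr in fields(source):
            (init_args if attr.init else lazy_args)[attr.name] = getattr(source, attr.name)
        init_args.update(kwargs)          # explicit overrides win over copied values
        result = cls(**init_args)
        for name, value in lazy_args.items():
            setattr(result, name, value)  # restore non-init state after construction
        return result

old = Args(x=5)
old.y = 99
new = Args.copyfrom(old, x=7)
assert (new.x, new.y) == (7, 99)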
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 import os
 import sys
 from typing import Any, Dict, Optional, Tuple
@@ -29,8 +28,8 @@ from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
 from transformers.utils.versions import require_version

+from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.logging import get_logger
 from ..extras.misc import check_dependencies, get_current_device
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
@@ -39,7 +38,7 @@ from .generating_args import GeneratingArguments
 from .model_args import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 check_dependencies()
@@ -57,7 +56,7 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
     if args is not None:
         return parser.parse_dict(args)

-    if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
+    if len(sys.argv) == 2 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
         return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))

     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
@@ -67,14 +66,14 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non

     if unknown_args:
         print(parser.format_help())
-        print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
-        raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
+        print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
+        raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

     return (*parsed_args,)


-def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
-    transformers.utils.logging.set_verbosity(log_level)
+def _set_transformers_logging() -> None:
+    transformers.utils.logging.set_verbosity_info()
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()

@@ -104,7 +103,7 @@ def _verify_model_args(
         raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

     if data_args.template == "yi" and model_args.use_fast_tokenizer:
-        logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
+        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
         model_args.use_fast_tokenizer = False

@@ -123,7 +122,7 @@ def _check_extra_dependencies(
         require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")

     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3,<=0.6.0", "To fix: pip install vllm>=0.4.3,<=0.6.0")
+        require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4")

     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")
@@ -261,7 +260,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

     if data_args.neat_packing and not data_args.packing:
-        logger.warning("`neat_packing` requires `packing` is True. Change `packing` to True.")
+        logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
         data_args.packing = True

     _verify_model_args(model_args, data_args, finetuning_args)
@@ -274,22 +273,26 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and model_args.resize_vocab
         and finetuning_args.additional_target is None
     ):
-        logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
+        logger.warning_rank0(
+            "Remember to add embedding layers to `additional_target` to make the added tokens trainable."
+        )

     if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
-        logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
+        logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")

     if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
-        logger.warning("We recommend enable mixed precision training.")
+        logger.warning_rank0("We recommend enable mixed precision training.")

     if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
-        logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.")
+        logger.warning_rank0(
+            "Using GaLore with mixed precision training may significantly increases GPU memory usage."
+        )

     if (not training_args.do_train) and model_args.quantization_bit is not None:
-        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
+        logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")

     if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
-        logger.warning("Specify `ref_model` for computing rewards at evaluation.")
+        logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")

     # Post-process training arguments
     if (
@@ -297,13 +300,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
-        logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
         training_args.ddp_find_unused_parameters = False

     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
         can_resume_from_checkpoint = False
         if training_args.resume_from_checkpoint is not None:
-            logger.warning("Cannot resume from checkpoint in current stage.")
+            logger.warning_rank0("Cannot resume from checkpoint in current stage.")
             training_args.resume_from_checkpoint = None
     else:
         can_resume_from_checkpoint = True
@@ -323,15 +326,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:

         if last_checkpoint is not None:
             training_args.resume_from_checkpoint = last_checkpoint
-            logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
-            logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
+            logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
+            logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")

     if (
         finetuning_args.stage in ["rm", "ppo"]
         and finetuning_args.finetuning_type == "lora"
         and training_args.resume_from_checkpoint is not None
     ):
-        logger.warning(
+        logger.warning_rank0(
             "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
                 training_args.resume_from_checkpoint
             )
@@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

-from ..extras.logging import get_logger
+from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
     from ..hparams import FinetuningArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _setup_full_tuning(
@@ -45,7 +45,7 @@ def _setup_full_tuning(
     if not is_trainable:
         return

-    logger.info("Fine-tuning method: Full")
+    logger.info_rank0("Fine-tuning method: Full")
     forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if not any(forbidden_module in name for forbidden_module in forbidden_modules):
@@ -64,7 +64,7 @@ def _setup_freeze_tuning(
     if not is_trainable:
         return

-    logger.info("Fine-tuning method: Freeze")
+    logger.info_rank0("Fine-tuning method: Freeze")
     if hasattr(model.config, "text_config"):  # composite models
         config = getattr(model.config, "text_config")
     else:
@@ -133,7 +133,7 @@ def _setup_freeze_tuning(
         else:
             param.requires_grad_(False)

-    logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
+    logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))


 def _setup_lora_tuning(
@@ -145,7 +145,7 @@ def _setup_lora_tuning(
     cast_trainable_params_to_fp32: bool,
 ) -> "PeftModel":
     if is_trainable:
-        logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+        logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))

     adapter_to_resume = None

@@ -182,7 +182,7 @@ def _setup_lora_tuning(
             model = model.merge_and_unload()

         if len(adapter_to_merge) > 0:
-            logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
+            logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")

         if adapter_to_resume is not None:  # resume lora training
             if model_args.use_unsloth:
@@ -190,7 +190,7 @@ def _setup_lora_tuning(
         else:
             model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)

-        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
+        logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

     if is_trainable and adapter_to_resume is None:  # create new lora weights while training
         if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -219,7 +219,7 @@ def _setup_lora_tuning(
                 module_names.add(name.split(".")[-1])

             finetuning_args.additional_target = module_names
-            logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
+            logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))

     peft_kwargs = {
         "r": finetuning_args.lora_rank,
@@ -236,11 +236,11 @@ def _setup_lora_tuning(
         else:
             if finetuning_args.pissa_init:
                 if finetuning_args.pissa_iter == -1:
-                    logger.info("Using PiSSA initialization.")
+                    logger.info_rank0("Using PiSSA initialization.")
                     peft_kwargs["init_lora_weights"] = "pissa"
                 else:
-                    logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
-                    peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
+                    logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
+                    peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"

     lora_config = LoraConfig(
         task_type=TaskType.CAUSAL_LM,
@@ -284,11 +284,11 @@ def init_adapter(
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
-        logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
+        logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
     elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
     else:
-        logger.info("Upcasting trainable params to float32.")
+        logger.info_rank0("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True

     if finetuning_args.finetuning_type == "full":
@@ -300,6 +300,6 @@ def init_adapter(
             config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
         )
     else:
-        raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))
+        raise NotImplementedError(f"Unknown finetuning type: {finetuning_args.finetuning_type}.")

     return model
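Editor's note: for reference, the PiSSA branch above selects between peft's accepted string forms for `init_lora_weights`: plain "pissa" for exact SVD initialization, or "pissa_niter_<k>" for the fast-SVD variant with k iterations. A tiny sketch of the string construction:

pissa_iter = 16
init_lora_weights = "pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}"
assert init_lora_weights == "pissa_niter_16"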
Some files were not shown because too many files have changed in this diff