76 Commits

Author SHA1 Message Date
hiyouga
f0bff18324 Update publish.yml
Former-commit-id: 60b0633e29c9e701aa3813bd1fdc0282bd07f7c8
2024-06-19 20:46:33 +08:00
hiyouga
b631bdc5b7 release v0.8.2
Former-commit-id: 3050bbe51d46acd8473275d2713fc28932e4a3d3
2024-06-19 20:42:09 +08:00
hiyouga
c65f7e9bd5 fix jinja template
Former-commit-id: 0ebf2e2ee23918d28b0cbb20ba456732d6eedfbb
2024-06-19 20:03:50 +08:00
hiyouga
3e0fa4a8da fix templates
Former-commit-id: 6f357d59b73309c5955683008632e7f320e7dcb1
2024-06-19 17:44:05 +08:00
hiyouga
235ed85b0f fix bug
Former-commit-id: 412139eaa2fde98ba19e1257d21144382a59f0d6
2024-06-19 03:49:23 +08:00
hiyouga
1ca639a777 use prefix to replace force system
Former-commit-id: 731d9a964f1c3dbfb83825524d697831e691fb9d
2024-06-19 03:39:52 +08:00
hiyouga
e36a994fe6 fix tool formatter, allow parallel function #4362
Former-commit-id: b8f16c976db4ecec1cc8558851c8cbfb6a5b7e9c
2024-06-19 03:23:51 +08:00
hoshi-hiyouga
19ffcfea76 Merge pull request #4173 from mMrBun/main
Implemented the tool_formatter and tool_extractor for glm4 and Qwen2 tool_format

Former-commit-id: 36b02ceed40198ecd5d559ee4ebef9205442ded2
2024-06-19 03:18:55 +08:00
hiyouga
85f3a09c83 tiny fix
Former-commit-id: bb750fa3dde03ec024ae75596ecd4b884cb126c6
2024-06-18 23:32:18 +08:00
hoshi-hiyouga
60b9a9c1fa Merge pull request #4314 from EliMCosta/patch-2
Fix Dockerfile

Former-commit-id: a123a42d98f5c49446762c1d4cfc674d2e4f61b1
2024-06-18 23:30:59 +08:00
hoshi-hiyouga
984e38575c Merge pull request #4309 from EliMCosta/patch-1
Add Magpie and Webinstruct dataset samples

Former-commit-id: 70966de5d4df51a41fef1da5a919dd622aa9c86c
2024-06-18 23:30:19 +08:00
hiyouga
665df5d733 add deepseek coder v2 #4346
Former-commit-id: d83d3846d8e3bf5c40d4b90c24e2c5909ec61864
2024-06-18 22:53:54 +08:00
hiyouga
4bc0bea0e9 fix #4357
Former-commit-id: a6741bba8cebd16a6a3f97a2dc81057d0e27eb39
2024-06-18 22:42:45 +08:00
hoshi-hiyouga
5cfa342f01 Merge pull request #4334 from zzxzz12345/bugfix/add-pandas-versions
Update requirements.txt

Former-commit-id: 219eb5b346bce7e13c2c3511c1638f9dde595787
2024-06-18 22:30:35 +08:00
hoshi-hiyouga
c106cc24e4 Update requirements.txt
Former-commit-id: da8684f9f0b0103d4fa81279343a48ecd0fcc0cd
2024-06-18 22:27:24 +08:00
hiyouga
372da52d4a fix #4335
Former-commit-id: 2ab449adbb160f339a0586edeb846fa311ad8382
2024-06-18 22:08:56 +08:00
hiyouga
875270b851 lint
Former-commit-id: a19a7ac99af62b6715c96274f6350b124a784331
2024-06-17 22:35:56 +08:00
hiyouga
43fab306b6 update chat engine #4335
Former-commit-id: b163df7de48777e4319c9ccc736b0acdd5f473ed
2024-06-17 19:07:17 +08:00
hiyouga
77242f4169 update readme
Former-commit-id: 07c629f77c3978f339402e578cde1aede3f37699
2024-06-17 18:47:24 +08:00
hiyouga
60d9896a70 fix #4326
Former-commit-id: 3c2c45812a720d92f7f5b15b9f03370fe6bf069e
2024-06-17 18:17:48 +08:00
hiyouga
485a80d294 tiny fix
Former-commit-id: 2289436567a7860d25d9da0afb39e4a3e5e83839
2024-06-17 17:47:25 +08:00
胡翀
63bfe9967e Update requirements.txt
add pandas version requirements

Former-commit-id: ed1cf559aa2d02588aacf55a17b439473651f626
2024-06-17 16:45:57 +08:00
Eli Costa
a720b82e63 Fix Dockerfile
Adds the commands to correctly launch the LLaMA-Factory servers

Former-commit-id: 22af40f0895a6f88709a495febeca8507d41d989
2024-06-16 19:16:23 -03:00
Eli Costa
d3b0048d8c Update README_zh.md
Fix details tag in datasets menus

Former-commit-id: d79c1bd4806e9ea13115fabebf9da2d19b0a52be
2024-06-16 11:34:31 -03:00
Eli Costa
9a0aca42a5 Update README_zh.md
Add Magpie and WebInstruct to README

Former-commit-id: 6cf5323959fe9500ba06ab28980fcc8f62e1373f
2024-06-16 11:22:06 -03:00
Eli Costa
5e802b0645 Update README.md
Add Magpie and Webinstruct to README

Former-commit-id: 2b32b9263f12605e48e11dce9b5fbb746d790745
2024-06-16 11:19:25 -03:00
hoshi-hiyouga
ca67b7a568 Update parser.py
Former-commit-id: d10c97193d08bd368aca1a72f0d1d8a96c76765d
2024-06-16 02:57:00 +08:00
hiyouga
76cd879c84 update pr template
Former-commit-id: 0b7c29674fda10c0ac87e0a0c75990feabb5a3de
2024-06-16 01:43:43 +08:00
hoshi-hiyouga
e0c049e590 Merge pull request #4307 from hiyouga/pissa
Support pissa

Former-commit-id: e7c0eefe96540c106162f5d252476b10b97ae696
2024-06-16 01:41:50 +08:00
hiyouga
727943f078 fix tol
Former-commit-id: bdb54bcb477126687db789bd89f2df84e424a2a3
2024-06-16 01:38:44 +08:00
hiyouga
8393b08666 Update tests.yml
Former-commit-id: 82e83615a706293abbf266d11c57caedafdd4c5b
2024-06-16 01:22:23 +08:00
hiyouga
9049f72d2f increase tol
Former-commit-id: c29071445e34aed23123fdf883a4d877744a1b0e
2024-06-16 01:21:06 +08:00
hiyouga
32f45c9e91 support pissa
Former-commit-id: ef8e45f2eaf466c54e9a671512a2974575677b08
2024-06-16 01:08:12 +08:00
hiyouga
05f3a3c944 tiny fix
Former-commit-id: f7f440986b0ae3b38ea9f2da80789629d4f79ea1
2024-06-16 01:06:41 +08:00
hiyouga
14f7bfc545 use fixture
Former-commit-id: 10761985691b9f934f7689c1f82aa6dd68febcca
2024-06-15 20:06:17 +08:00
hiyouga
7f90b0cd20 add tests
Former-commit-id: 484634ee9c982e82e919ff67d507e0210345182d
2024-06-15 19:51:20 +08:00
hiyouga
308abfec6c add minicpm #4227
Former-commit-id: e1bb18ce60be9a1b203989def30f1b9194286325
2024-06-15 17:58:52 +08:00
hiyouga
bb88536166 add license
Former-commit-id: 69cfc98d7c81756a5ab6bf962240e393e449fef0
2024-06-15 17:54:33 +08:00
hiyouga
d2df3f2d6e update readme
Former-commit-id: a43d302aa79cbfb9b0606e855b4c1af6865d8e68
2024-06-15 05:13:16 +08:00
hiyouga
2abfad9c1f fix #4271
Former-commit-id: 03707e78d29bfcf5d395a64bb38632bdb3ff47ce
2024-06-15 05:11:33 +08:00
hiyouga
2af932d969 disable DP
Former-commit-id: c18fd609d268389f3e65274992045a6c9f8e6c1f
2024-06-15 04:57:19 +08:00
hiyouga
c29fa61a9c fix #4292
Former-commit-id: 4cd4c179d24eab0fcaec2b29b9dd71970f877fe8
2024-06-15 04:47:13 +08:00
hiyouga
a30931fe0f fix #4295
Former-commit-id: 08f657868f9d605b837c5d8c2946a25cc05c8735
2024-06-15 04:34:55 +08:00
hiyouga
3ff9b87012 add test cases
Former-commit-id: 731176ff34cdf0cbf6b41c40c69f4ceb54c2daf6
2024-06-15 04:05:54 +08:00
hiyouga
f4f315fd11 Update README.md
Former-commit-id: f8d701cd3ce2e56f95b4f5439b8b48d5b62e0d2b
2024-06-13 16:02:21 +08:00
hiyouga
530165d9a5 update examples
Former-commit-id: d6bf6231290d79eb3a63e711f18fa711ef18a4f6
2024-06-13 03:26:10 +08:00
hiyouga
dbd1458adf add quant check in webui export tab
Former-commit-id: 6455ca07061ae9858cd7bc996b28be1fde697a3d
2024-06-13 03:19:18 +08:00
hiyouga
dedefecd2b Update llama3_full_sft_ds3.yaml
Former-commit-id: e715af62d521112d9c155cfa91fbb42fa0e77710
2024-06-13 03:16:20 +08:00
hiyouga
46f441dd37 update examples
Former-commit-id: 19681f93db399d695aa8e35f8ec2a9e720875baa
2024-06-13 03:15:06 +08:00
hiyouga
49b58fd6af fix #4221
Former-commit-id: 05a3be4853b941909e7d193c31e8d62c8c5f879b
2024-06-13 02:48:21 +08:00
hiyouga
103a507b39 fix #4209
DeepSpeed ZeRO3 has inflight param error when calling model.eval()


Former-commit-id: 4be013f18ea6a35b5a11db98db5f0670ffb41619
2024-06-13 02:25:50 +08:00
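For context on the issue above, a minimal sketch of the failure mode (illustrative only, not the actual patch for #4209): under ZeRO-3, parameters are partitioned across ranks and gathered on demand, so an extra `model.eval()` call while a gather is still in flight can trip DeepSpeed's inflight-parameter assertion. One defensive pattern is to skip the call when ZeRO-3 is active; `is_deepspeed_zero3_enabled` is an existing transformers helper.

```python
# Illustrative sketch, not the repository's actual fix for #4209.
from transformers.integrations import is_deepspeed_zero3_enabled


def safe_eval(model):
    """Switch to eval mode only when no ZeRO-3 partitioning is involved."""
    if is_deepspeed_zero3_enabled():
        return model  # let the DeepSpeed engine manage module state instead
    return model.eval()
```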
hiyouga
0a75224f62 clean code
Former-commit-id: f54cafd5c7f0383370d1a2f357834a61a97397ce
2024-06-13 01:58:16 +08:00
hoshi-hiyouga
04d7629abf Merge pull request #4246 from hzhaoy/adapt-vllm-v0.5.0
adapt vllm==0.5.0

Former-commit-id: 1068e25fc8b89f11cc79b164ee4aef9ce137ad4c
2024-06-13 01:54:02 +08:00
hiyouga
1b6786a21f add neo-sft dataset
Former-commit-id: 34863fa7cb641ceca92e3a2eec914126db537b62
2024-06-13 01:00:56 +08:00
hiyouga
5080f2314c fix lint
Former-commit-id: b170165679317af2b3f03633afac27661b3deb06
2024-06-13 00:48:44 +08:00
hiyouga
41beb7f0a3 fix docker compose usage
Former-commit-id: 59a5bd5d5c8d2a44e2dad26b74e77a45e109c8d6
2024-06-13 00:07:48 +08:00
hzhaoy
799873aa14 adapt vllm==0.5.0
Former-commit-id: 02afd9ff64f23e6707ac739ae1269f41bd70c340
2024-06-12 18:29:03 +08:00
hiyouga
fe2c7eaa93 update readme
Former-commit-id: a436aaa83f0cf12c8f404459e5486f9369d538ec
2024-06-12 17:39:12 +08:00
hiyouga
6392d45ea7 fix #4242
Former-commit-id: cf260e7af03f49aa5e3d6daf3b27738ff9b9bcb8
2024-06-12 16:50:11 +08:00
hoshi-hiyouga
c60ea675d7 Merge pull request #4234 from kimdwkimdw/patch-1
Support vllm==0.5.0

Former-commit-id: 0a9da057c9e7ef11cd709b20263c3d2e4c2d72ed
2024-06-12 16:39:09 +08:00
Arthur Kim
16c7c92396 Support vllm==0.5.0
Former-commit-id: e7a8ffd7af21bc3759f055033ba2209fa7a1be0e
2024-06-12 16:49:12 +09:00
hoshi-hiyouga
7598b37543 Merge pull request #4204 from dignfei/main
fix bug: Llama-3 should use <|end_of_text|> to mark the end of text during continual pre-training

Former-commit-id: e566342636faf0031a0ba5d5dd4fcff8401a2b76
2024-06-11 17:06:10 +08:00
hoshi-hiyouga
cc9717e2f2 Update pretrain.py
Former-commit-id: e2317b2a84149e39fddfd6366be3de23dfb71f82
2024-06-11 17:02:14 +08:00
hiyouga
08f2f99f4b fix deepspeed version
Former-commit-id: 938a69bb07d4de7d82928ff01c582032162c1480
2024-06-11 16:52:36 +08:00
d
77bf3d66c7 After extensive continual pre-training and comparison experiments, we found this bug: the tokenizer.eos_token Llama-3 uses during pre-training is '<|end_of_text|>', so each training sample must also be terminated with it rather than '<|eot_id|>'; otherwise severe performance degradation is likely
Former-commit-id: ef470561f742b16eaa0f99c4cadecd7c84ce6bd2
2024-06-11 16:23:40 +08:00
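A minimal sketch of the point made in that commit (assumed packing logic, not the repository's exact code): during continual pre-training, each document should be terminated with `tokenizer.eos_token` ('<|end_of_text|>' for Llama-3), not the chat-turn delimiter '<|eot_id|>'.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")


def build_pretrain_corpus(texts: list[str]) -> list[int]:
    # Append eos_token so the model sees '<|end_of_text|>' at document
    # boundaries; using '<|eot_id|>' here degrades base-model quality.
    joined = "".join(text + tokenizer.eos_token for text in texts)
    return tokenizer(joined, add_special_tokens=False)["input_ids"]
```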
hiyouga
f14f67f803 Update bug-report.yml
Former-commit-id: bb022cd867ebf2593e40fc6ba43b768603b129a3
2024-06-11 15:40:21 +08:00
hiyouga
820b6e7b32 fix #4198
Former-commit-id: 945d2c6cc73542adf9272ebd9aa332ea2c1c7361
2024-06-11 15:38:38 +08:00
hiyouga
27aece94cf tiny fix
Former-commit-id: c4b2e263d9cefbad0fbc5de72422e4ef8edbcb54
2024-06-11 12:48:53 +08:00
hoshi-hiyouga
3f2508be92 Merge pull request #4191 from iamthebot/al--add_manifest_for_reqs
Add MANIFEST.in so requirements.txt is present in sdist

Former-commit-id: fd6d1c3fce855d1ef7396cf33af9f12eadc5a878
2024-06-11 10:41:15 +08:00
Alfredo Luque
fce11bb386 add manifest so requirements.txt in sdist
Former-commit-id: b501a3c56c51786c3006a2aca15a145641a4556c
2024-06-11 00:07:06 +00:00
hiyouga
2723438531 tiny fix
Former-commit-id: b5e9711ef375cc323fc083e742cccfc974550416
2024-06-11 01:04:16 +08:00
hiyouga
f330b73682 set dev version
Former-commit-id: 16c47cc15226119e33e46ba0f2f6ccb37072257f
2024-06-11 00:50:53 +08:00
mMrBun
bc04ca464a Optimize the handling of QWEN2 in scenarios involving multiple tool calls.
Former-commit-id: 48f870edc96ada40360f7e6e67cbf58805295b33
2024-06-10 02:00:14 +08:00
mMrBun
44829df762 Removed unnecessary comments.
Former-commit-id: 2b81252aa693871098931cd7873ef83ef4922ba5
2024-06-09 18:25:22 +08:00
mMrBun
94ddfa66c0 Merge branch 'hiyouga:main' into main
Former-commit-id: c25734d874a36222e0a540a2c994bbda73008b27
2024-06-09 18:17:24 +08:00
mMrBun
8db8ed5a41 Implemented the tool_formatter and tool_extractor for glm4 tool_format
Former-commit-id: db7fa4490ea7f6966418d2879c895cbc1763b16d
2024-06-09 18:16:15 +08:00
164 changed files with 3364 additions and 660 deletions

View File

@@ -38,7 +38,9 @@ body:
         请合理使用 Markdown 标签来格式化您的文本。
       placeholder: |
+        ```bash
         llamafactory-cli train ...
+        ```

   - type: textarea
     id: expected-behavior

View File

@@ -5,3 +5,4 @@ Fixes # (issue)

 ## Before submitting

 - [ ] Did you read the [contributor guideline](https://github.com/hiyouga/LLaMA-Factory/blob/main/.github/CONTRIBUTING.md)?
+- [ ] Did you write any new necessary tests?

40
.github/workflows/publish.yml vendored Normal file
View File

@@ -0,0 +1,40 @@
+name: publish
+
+on:
+  release:
+    types:
+      - published
+
+jobs:
+  publish:
+    name: Upload release to PyPI
+    runs-on: ubuntu-latest
+    environment:
+      name: release
+      url: https://pypi.org/p/llamafactory
+    permissions:
+      id-token: write
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.8"
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install build
+
+      - name: Build package
+        run: |
+          python -m build
+
+      - name: Publish package
+        uses: pypa/gh-action-pypi-publish@release/v1

View File

@@ -9,8 +9,6 @@ on:
       - "requirements.txt"
       - ".github/workflows/*.yml"
   pull_request:
-    types:
-      - review_requested
     branches:
       - main
     paths:

View File

@@ -32,7 +32,7 @@ RUN EXTRA_PACKAGES="metrics"; \
         EXTRA_PACKAGES="${EXTRA_PACKAGES},deepspeed"; \
     fi; \
     pip install -e .[$EXTRA_PACKAGES] && \
-    pip uninstall -y transformer-engine
+    pip uninstall -y transformer-engine flash-attn

 # Set up volumes
 VOLUME [ "/root/.cache/huggingface/", "/app/data", "/app/output" ]

@@ -42,3 +42,6 @@ EXPOSE 7860

 # Expose port 8000 for the API service
 EXPOSE 8000
+
+# Launch LLaMA Board
+CMD [ "llamafactory-cli", "webui" ]

1
MANIFEST.in Normal file
View File

@@ -0,0 +1 @@
+include LICENSE requirements.txt
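The one-line manifest matters because setuptools only bundles declared data files into an sdist. If setup.py derives install_requires from requirements.txt at build time (a common pattern, sketched below as an assumption about this project, not its exact setup script), installing from the tarball would otherwise fail with a missing file.

```python
# Hypothetical setup.py excerpt showing why requirements.txt must ship in the sdist.
import os

from setuptools import setup


def get_requires():
    path = os.path.join(os.path.dirname(__file__), "requirements.txt")
    with open(path, encoding="utf-8") as f:
        # keep non-empty, non-comment lines as pinned requirements
        return [line.strip() for line in f if line.strip() and not line.startswith("#")]


setup(name="llamafactory", install_requires=get_requires())
```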

View File

@@ -11,4 +11,4 @@ style:
 	ruff format $(check_dirs)

 test:
-	pytest tests/
+	CUDA_VISIBLE_DEVICES= pytest tests/
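Setting `CUDA_VISIBLE_DEVICES` to an empty string hides every GPU from the process, so the test suite runs on CPU regardless of the host. A quick sanity check of that behavior (assuming torch is installed):

```python
import os

os.environ["CUDA_VISIBLE_DEVICES"] = ""  # must be set before torch initializes CUDA

import torch

assert not torch.cuda.is_available()  # all tests now execute on CPU
```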

View File

@@ -49,7 +49,7 @@ Choose your path:

 - **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc.
 - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
-- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ and Agent tuning.
+- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
 - **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.

@@ -71,9 +71,9 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

 ## Changelog

-[24/06/07] We supported fine-tuning the **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** series models.
-
-[24/06/05] We supported fine-tuning the **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** models.
+[24/06/16] We support **[PiSSA](https://arxiv.org/abs/2404.02948)** algorithm. See [examples](examples/README.md) for usage.
+
+[24/06/07] We supported fine-tuning the **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** and **[GLM-4](https://github.com/THUDM/GLM-4)** models.

 [24/05/26] We supported **[SimPO](https://arxiv.org/abs/2405.14734)** algorithm for preference learning. See [examples](examples/README.md) for usage.

@@ -151,35 +151,35 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

 ## Supported Models

 | Model | Model size | Template |
 | ----- | ---------- | -------- |
 | [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
 | [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
 | [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
 | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
 | [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
-| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
 | [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
 | [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
 | [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
 | [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma](https://huggingface.co/google) | 3B | gemma |
 | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
 | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
 | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
 | [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
 | [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
 | [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
 | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
 | [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
 | [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
 | [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

 > [!NOTE]
 > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.

@@ -259,6 +259,9 @@ You also can add a custom chat template to [template.py](src/llamafactory/data/template.py)

 - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
 - [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
 - [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
+- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
+- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
+- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
 - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
 - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
 - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)

@@ -405,7 +408,7 @@ Please refer to [data/README.md](data/README.md) for checking the details about

 Use the following 3 commands to run LoRA **fine-tuning**, **inference** and **merging** of the Llama3-8B-Instruct model, respectively.

 ```bash
-llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
 ```

@@ -423,6 +426,8 @@ llamafactory-cli webui

 ### Build Docker

+#### Use Docker
+
 ```bash
 docker build -f ./Dockerfile \
     --build-arg INSTALL_BNB=false \

@@ -442,8 +447,12 @@ docker run -it --gpus=all \
     llamafactory:latest
 ```

-> [!TIP]
-> Use Docker Compose to build image via `docker compose up -d`.
+#### Use Docker Compose
+
+```bash
+docker-compose up -d
+docker-compose exec llamafactory bash
+```

 <details><summary>Details about volume</summary>

@@ -456,7 +465,7 @@ docker run -it --gpus=all \

 ### Deploy with OpenAI-style API and vLLM

 ```bash
-CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 ```

@@ -474,7 +483,7 @@ Train the model by specifying a model ID of the ModelScope Hub as the `model_name_or_path`

 ### Use W&B Logger

-To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments.
+To use [Weights & Biases](https://wandb.ai) for logging experimental results, you need to add the following arguments to yaml files.

 ```yaml
 report_to: wandb

View File

@@ -49,7 +49,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

 - **多种模型**：LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
 - **集成方法**：（增量）预训练、（多模态）指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。
 - **多种精度**：32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
-- **先进算法**：GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ 和 Agent 微调。
+- **先进算法**：GaLore、BAdam、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。
 - **实用技巧**：FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
 - **实验监控**：LlamaBoard、TensorBoard、Wandb、MLflow 等等。
 - **极速推理**：基于 vLLM 的 OpenAI 风格 API、浏览器界面和命令行接口。

@@ -71,9 +71,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

 ## 更新日志

-[24/06/07] 我们支持了 **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** 系列模型的微调。
-
-[24/06/05] 我们支持了 **[GLM-4-9B/GLM-4-9B-Chat](https://github.com/THUDM/GLM-4)** 模型的微调。
+[24/06/16] 我们支持了 **[PiSSA](https://arxiv.org/abs/2404.02948)** 算法。详细用法请参照 [examples](examples/README_zh.md)。
+
+[24/06/07] 我们支持了 **[Qwen2](https://qwenlm.github.io/blog/qwen2/)** 和 **[GLM-4](https://github.com/THUDM/GLM-4)** 模型的微调。

 [24/05/26] 我们支持了 **[SimPO](https://arxiv.org/abs/2405.14734)** 偏好对齐算法。详细用法请参照 [examples](examples/README_zh.md)。

@@ -151,35 +151,35 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

 ## 模型

 | 模型名 | 模型大小 | Template |
 | ------ | -------- | -------- |
 | [Baichuan2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
 | [BLOOM](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
 | [BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
 | [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
 | [Command-R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
-| [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
 | [Gemma/CodeGemma](https://huggingface.co/google) | 2B/7B | gemma |
 | [GLM4](https://huggingface.co/THUDM) | 9B | glm4 |
 | [InternLM2](https://huggingface.co/internlm) | 7B/20B | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
 | [LLaMA-3](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | vicuna |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
 | [PaliGemma](https://huggingface.co/google) | 3B | gemma |
 | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
 | [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
 | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | qwen |
 | [Qwen1.5 (Code/MoE)](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/32B/72B/110B | qwen |
 | [Qwen2 (MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/7B/57B/72B | qwen |
 | [StarCoder2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
 | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
 | [Yi (1/1.5)](https://huggingface.co/01-ai) | 6B/9B/34B | yi |
 | [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
 | [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |

 > [!NOTE]
 > 对于所有“基座”（Base）模型，`template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”（Instruct/Chat）模型请务必使用**对应的模板**。

@@ -259,6 +259,9 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd

 - [Cosmopedia (en)](https://huggingface.co/datasets/HuggingFaceTB/cosmopedia)
 - [STEM (zh)](https://huggingface.co/datasets/hfl/stem_zh_instruction)
 - [Ruozhiba (zh)](https://huggingface.co/datasets/hfl/ruozhiba_gpt4_turbo)
+- [Neo-sft (zh)](https://huggingface.co/datasets/m-a-p/neo_sft_phase2)
+- [WebInstructSub (en)](https://huggingface.co/datasets/TIGER-Lab/WebInstructSub)
+- [Magpie-Pro-300K-Filtered (en)](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
 - [LLaVA mixed (en&zh)](https://huggingface.co/datasets/BUAADreamer/llava-en-zh-300k)
 - [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
 - [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)

@@ -405,7 +408,7 @@ Docker 镜像:

 下面三行命令分别对 Llama3-8B-Instruct 模型进行 LoRA **微调**、**推理**和**合并**。

 ```bash
-llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
 ```

@@ -423,6 +426,8 @@ llamafactory-cli webui

 ### 构建 Docker

+#### 使用 Docker
+
 ```bash
 docker build -f ./Dockerfile \
     --build-arg INSTALL_BNB=false \

@@ -442,8 +447,12 @@ docker run -it --gpus=all \
     llamafactory:latest
 ```

-> [!TIP]
-> 通过 `docker compose up -d` 使用 Docker Compose 构建镜像。
+#### 使用 Docker Compose
+
+```bash
+docker-compose up -d
+docker-compose exec llamafactory bash
+```

 <details><summary>数据卷详情</summary>

@@ -456,7 +465,7 @@ docker run -it --gpus=all \

 ### 利用 vLLM 部署 OpenAI API

 ```bash
-CUDA_VISIBLE_DEVICES=0,1 API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
+API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 ```

@@ -474,7 +483,7 @@ export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`

 ### 使用 W&B 面板

-若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据，请添加下面的参数。
+若要使用 [Weights & Biases](https://wandb.ai) 记录实验数据，请在 yaml 文件中添加下面的参数。

 ```yaml
 report_to: wandb

View File

@@ -1,5 +1,3 @@
-version: '3.8'
-
 services:
   llamafactory:
     build:

@@ -19,6 +17,9 @@ services:
       - "7860:7860"
       - "8000:8000"
     ipc: host
+    tty: true
+    stdin_open: true
+    command: bash
     deploy:
       resources:
         reservations:

View File

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os

 import datasets

View File

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os

 import datasets

View File

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os

 import datasets

View File

@@ -4,59 +4,59 @@ Make sure to execute these commands in the `LLaMA-Factory` directory.

 ## Table of Contents

-- [LoRA Fine-Tuning on A Single GPU](#lora-fine-tuning-on-a-single-gpu)
-- [QLoRA Fine-Tuning on a Single GPU](#qlora-fine-tuning-on-a-single-gpu)
-- [LoRA Fine-Tuning on Multiple GPUs](#lora-fine-tuning-on-multiple-gpus)
-- [LoRA Fine-Tuning on Multiple NPUs](#lora-fine-tuning-on-multiple-npus)
-- [Full-Parameter Fine-Tuning on Multiple GPUs](#full-parameter-fine-tuning-on-multiple-gpus)
+- [LoRA Fine-Tuning](#lora-fine-tuning)
+- [QLoRA Fine-Tuning](#qlora-fine-tuning)
+- [Full-Parameter Fine-Tuning](#full-parameter-fine-tuning)
 - [Merging LoRA Adapters and Quantization](#merging-lora-adapters-and-quantization)
 - [Inferring LoRA Fine-Tuned Models](#inferring-lora-fine-tuned-models)
 - [Extras](#extras)

+Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.
+
 ## Examples

-### LoRA Fine-Tuning on A Single GPU
+### LoRA Fine-Tuning

 #### (Continuous) Pre-Training

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
 ```

 #### Supervised Fine-Tuning

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 ```

 #### Multimodal Supervised Fine-Tuning

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
 ```

 #### Reward Modeling

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
 ```

 #### PPO Training

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
 ```

 #### DPO/ORPO/SimPO Training

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
 ```

 #### KTO Training

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
 ```

 #### Preprocess Dataset

@@ -64,95 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo

 It is useful for large dataset, use `tokenized_path` in config to load the preprocessed dataset.

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml
+llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
 ```

 #### Evaluating on MMLU/CMMLU/C-Eval Benchmarks

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml
+llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
 ```

 #### Batch Predicting and Computing BLEU and ROUGE Scores

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml
-```
-
-### QLoRA Fine-Tuning on a Single GPU
-
-#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
-```
-
-#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
-```
-
-#### Supervised Fine-Tuning with 4-bit AWQ Quantization
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
-```
-
-#### Supervised Fine-Tuning with 2-bit AQLM Quantization
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
-```
-
-### LoRA Fine-Tuning on Multiple GPUs
-
-#### Supervised Fine-Tuning on Single Node
-
-```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
 ```

 #### Supervised Fine-Tuning on Multiple Nodes

 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 ```

 #### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)

 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
 ```

-### LoRA Fine-Tuning on Multiple NPUs
+### QLoRA Fine-Tuning

-#### Supervised Fine-Tuning with DeepSpeed ZeRO-0
+#### Supervised Fine-Tuning with 4/8-bit Bitsandbytes Quantization (Recommended)

 ```bash
-ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
 ```

-### Full-Parameter Fine-Tuning on Multiple GPUs
+#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
+```
+
+#### Supervised Fine-Tuning with 4-bit AWQ Quantization
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
+```
+
+#### Supervised Fine-Tuning with 2-bit AQLM Quantization
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
+```
+
+### Full-Parameter Fine-Tuning

 #### Supervised Fine-Tuning on Single Node

 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
 ```

 #### Supervised Fine-Tuning on Multiple Nodes

 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
 ```

 #### Batch Predicting and Computing BLEU and ROUGE Scores

 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
+llamafactory-cli train examples/train_full/llama3_full_predict.yaml
 ```

 ### Merging LoRA Adapters and Quantization

@@ -162,35 +146,33 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llam

 Note: DO NOT use quantized model or `quantization_bit` when merging LoRA adapters.

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
 ```

 #### Quantizing Model using AutoGPTQ

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
+llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
 ```

 ### Inferring LoRA Fine-Tuned Models

+Use `CUDA_VISIBLE_DEVICES=0,1` to infer models on multiple devices.
+
 #### Use CLI

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
 ```

 #### Use Web UI

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
 ```

 #### Launch OpenAI-style API

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
+llamafactory-cli api examples/inference/llama3_lora_sft.yaml
 ```

 ### Extras

@@ -198,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y

 #### Full-Parameter Fine-Tuning using GaLore

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
 ```

 #### Full-Parameter Fine-Tuning using BAdam

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
 ```

 #### LoRA+ Fine-Tuning

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
+llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
 ```

+#### PiSSA Fine-Tuning
+
+```bash
+llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
+```
+
 #### Mixture-of-Depths Fine-Tuning

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
 ```

 #### LLaMA-Pro Fine-Tuning

 ```bash
 bash examples/extras/llama_pro/expand.sh
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
+llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
 ```

 #### FSDP+QLoRA Fine-Tuning

 ```bash
-bash examples/extras/fsdp_qlora/single_node.sh
+bash examples/extras/fsdp_qlora/train.sh
 ```

View File

@@ -4,59 +4,59 @@

 ## 目录

-- [单 GPU LoRA 微调](#单-gpu-lora-微调)
-- [单 GPU QLoRA 微调](#单-gpu-qlora-微调)
-- [多 GPU LoRA 微调](#多-gpu-lora-微调)
-- [多 NPU LoRA 微调](#多-npu-lora-微调)
-- [多 GPU 全参数微调](#多-gpu-全参数微调)
+- [LoRA 微调](#lora-微调)
+- [QLoRA 微调](#qlora-微调)
+- [全参数微调](#全参数微调)
 - [合并 LoRA 适配器与模型量化](#合并-lora-适配器与模型量化)
 - [推理 LoRA 模型](#推理-lora-模型)
 - [杂项](#杂项)

+使用 `CUDA_VISIBLE_DEVICES`（GPU）或 `ASCEND_RT_VISIBLE_DEVICES`（NPU)选择计算设备。
+
 ## 示例

-### 单 GPU LoRA 微调
+### LoRA 微调

 #### (增量)预训练

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_pretrain.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_pretrain.yaml
 ```

 #### 指令监督微调

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 ```

 #### 多模态指令监督微调

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llava1_5_lora_sft.yaml
+llamafactory-cli train examples/train_lora/llava1_5_lora_sft.yaml
 ```

 #### 奖励模型训练

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_reward.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_reward.yaml
 ```

 #### PPO 训练

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_ppo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_ppo.yaml
 ```

 #### DPO/ORPO/SimPO 训练

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_dpo.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_dpo.yaml
 ```

 #### KTO 训练

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_kto.yaml
+llamafactory-cli train examples/train_lora/llama3_lora_kto.yaml
 ```

 #### 预处理数据集

@@ -64,95 +64,79 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lo

 对于大数据集有帮助，在配置中使用 `tokenized_path` 以加载预处理后的数据集。

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_preprocess.yaml
+llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
 ```

 #### 在 MMLU/CMMLU/C-Eval 上评估

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli eval examples/lora_single_gpu/llama3_lora_eval.yaml
+llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
 ```

 #### 批量预测并计算 BLEU 和 ROUGE 分数

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/lora_single_gpu/llama3_lora_predict.yaml
-```
-
-### 单 GPU QLoRA 微调
-
-#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_bitsandbytes.yaml
-```
-
-#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_gptq.yaml
-```
-
-#### 基于 4 比特 AWQ 量化进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_awq.yaml
-```
-
-#### 基于 2 比特 AQLM 量化进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/qlora_single_gpu/llama3_lora_sft_aqlm.yaml
-```
-
-### 多 GPU LoRA 微调
-
-#### 在单机上进行指令监督微调
-
-```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
-```
-
-#### 在多机上进行指令监督微调
+llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
+```
+
+#### 多机指令监督微调

 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 ```

 #### 使用 DeepSpeed ZeRO-3 平均分配显存

 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_gpu/llama3_lora_sft_ds.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
 ```

-### 多 NPU LoRA 微调
+### QLoRA 微调

-#### 使用 DeepSpeed ZeRO-0 进行指令监督微调
+#### 基于 4/8 比特 Bitsandbytes 量化进行指令监督微调(推荐)

 ```bash
-ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/lora_multi_npu/llama3_lora_sft_ds.yaml
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_bitsandbytes.yaml
 ```

-### 多 GPU 全参数微调
+#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_gptq.yaml
+```
+
+#### 基于 4 比特 AWQ 量化进行指令监督微调
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_awq.yaml
+```
+
+#### 基于 2 比特 AQLM 量化进行指令监督微调
+
+```bash
+llamafactory-cli train examples/train_qlora/llama3_lora_sft_aqlm.yaml
+```
+
+### 全参数微调

 #### 在单机上进行指令监督微调

 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
 ```

 #### 在多机上进行指令监督微调

 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
-CUDA_VISIBLE_DEVICES=0,1,2,3 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/full_multi_gpu/llama3_full_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
+FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft_ds3.yaml
 ```

 #### 批量预测并计算 BLEU 和 ROUGE 分数

 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llama3_full_predict.yaml
+llamafactory-cli train examples/train_full/llama3_full_predict.yaml
 ```

 ### 合并 LoRA 适配器与模型量化

@@ -162,35 +146,33 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 llamafactory-cli train examples/full_multi_gpu/llam

 注：请勿使用量化后的模型或 `quantization_bit` 参数来合并 LoRA 适配器。

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
+llamafactory-cli export examples/merge_lora/llama3_lora_sft.yaml
 ```

 #### 使用 AutoGPTQ 量化模型

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
+llamafactory-cli export examples/merge_lora/llama3_gptq.yaml
 ```

 ### 推理 LoRA 模型

+使用 `CUDA_VISIBLE_DEVICES=0,1` 进行多卡推理。
+
 #### 使用命令行接口

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
 ```

 #### 使用浏览器界面

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
+llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
 ```

 #### 启动 OpenAI 风格 API

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.yaml
+llamafactory-cli api examples/inference/llama3_lora_sft.yaml
 ```

 ### 杂项

@@ -198,36 +180,42 @@ CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/llama3_lora_sft.y

 #### 使用 GaLore 进行全参数训练

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/galore/llama3_full_sft.yaml
 ```

 #### 使用 BAdam 进行全参数训练

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/badam/llama3_full_sft.yaml
 ```

 #### LoRA+ 微调

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
+llamafactory-cli train examples/extras/loraplus/llama3_lora_sft.yaml
 ```

+#### PiSSA 微调
+
+```bash
+llamafactory-cli train examples/extras/pissa/llama3_lora_sft.yaml
+```
+
 #### 深度混合微调

 ```bash
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
+llamafactory-cli train examples/extras/mod/llama3_full_sft.yaml
 ```

 #### LLaMA-Pro 微调

 ```bash
 bash examples/extras/llama_pro/expand.sh
-CUDA_VISIBLE_DEVICES=0 llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
+llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
 ```

 #### FSDP+QLoRA 微调

 ```bash
-bash examples/extras/fsdp_qlora/single_node.sh
+bash examples/extras/fsdp_qlora/train.sh
 ```

View File

@@ -8,9 +8,6 @@ do_train: true
 finetuning_type: lora
 lora_target: all

-### ddp
-ddp_timeout: 180000000
-
 ### dataset
 dataset: identity,alpaca_en_demo
 template: llama3

@@ -34,6 +31,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000

 ### eval
 val_size: 0.1
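The recurring `ddp_timeout: 180000000` key in these diffs maps to `TrainingArguments.ddp_timeout` in transformers, the timeout in seconds that distributed workers wait on one another (e.g. while rank 0 preprocesses a large dataset). A minimal sketch of the mapping:

```python
from transformers import TrainingArguments

# ddp_timeout is forwarded to torch.distributed's process-group timeout
args = TrainingArguments(output_dir="out", ddp_timeout=180000000)
print(args.ddp_timeout)  # 180000000 seconds
```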

View File

@@ -32,6 +32,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000

 ### eval
 val_size: 0.1

View File

@@ -31,6 +31,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000

 ### eval
 val_size: 0.1

View File

@@ -31,6 +31,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 pure_bf16: true
+ddp_timeout: 180000000

 ### eval
 val_size: 0.1

View File

@@ -6,9 +6,9 @@ stage: sft
 do_train: true
 finetuning_type: lora
 lora_target: all
-
-### ddp
-ddp_timeout: 180000000
+pissa_init: true
+pissa_iter: 4
+pissa_convert: true

 ### dataset
 dataset: identity,alpaca_en_demo

@@ -27,12 +27,13 @@ overwrite_output_dir: true

 ### train
 per_device_train_batch_size: 1
-gradient_accumulation_steps: 2
+gradient_accumulation_steps: 8
 learning_rate: 1.0e-4
 num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000

 ### eval
 val_size: 0.1

View File

@@ -5,9 +5,6 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 stage: sft
 do_train: true
 finetuning_type: full
-### ddp
-ddp_timeout: 180000000
 deepspeed: examples/deepspeed/ds_z3_config.json
 ### dataset
@@ -33,6 +30,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### eval
 val_size: 0.1

View File

@@ -32,6 +32,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### eval
 val_size: 0.1

View File

@@ -6,6 +6,7 @@ stage: kto
 do_train: true
 finetuning_type: lora
 lora_target: all
+pref_beta: 0.1
 ### dataset
 dataset: kto_en_demo
@@ -30,6 +31,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### eval
 val_size: 0.1

View File

@@ -31,6 +31,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### generate
 max_new_tokens: 512

View File

@@ -22,3 +22,4 @@ overwrite_output_dir: true
 ### eval
 per_device_eval_batch_size: 1
 predict_with_generate: true
+ddp_timeout: 180000000

View File

@@ -29,6 +29,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### eval
 val_size: 0.1

View File

@@ -30,6 +30,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### eval
 val_size: 0.1

View File

@@ -30,6 +30,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### eval
 val_size: 0.1

View File

@@ -6,9 +6,6 @@ stage: sft
 do_train: true
 finetuning_type: lora
 lora_target: all
-### ddp
-ddp_timeout: 180000000
 deepspeed: examples/deepspeed/ds_z0_config.json
 ### dataset
@@ -34,6 +31,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### eval
 val_size: 0.1

View File

@@ -6,9 +6,6 @@ stage: sft
 do_train: true
 finetuning_type: lora
 lora_target: all
-### ddp
-ddp_timeout: 180000000
 deepspeed: examples/deepspeed/ds_z3_config.json
 ### dataset
@@ -34,6 +31,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### eval
 val_size: 0.1

View File

@@ -31,6 +31,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### eval
 val_size: 0.1

View File

@@ -30,6 +30,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### eval
 val_size: 0.1

View File

@@ -30,6 +30,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### eval
 val_size: 0.1

View File

@@ -31,6 +31,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### eval
 val_size: 0.1

View File

@@ -30,6 +30,7 @@ num_train_epochs: 3.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 fp16: true
+ddp_timeout: 180000000
 ### eval
 val_size: 0.1

View File

@@ -4,6 +4,7 @@ accelerate>=0.30.1
 peft>=0.11.1
 trl>=0.8.6
 gradio>=4.0.0
+pandas>=2.0.0
 scipy
 einops
 sentencepiece

View File

@@ -1,7 +1,20 @@
 # coding=utf-8
-# Calculates the flops of pre-trained models.
-# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
-# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
+# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
+#
+# This code is inspired by Microsoft's DeepSpeed library.
+# https://www.deepspeed.ai/tutorials/flops-profiler/
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import fire
 import torch
@@ -17,6 +30,10 @@ def calculate_flops(
     seq_length: int = 256,
     flash_attn: str = "auto",
 ):
+    r"""
+    Calculates the flops of pre-trained models.
+    Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
+    """
     with get_accelerator().device(0):
         chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
         fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)

View File

@@ -1,7 +1,20 @@
 # coding=utf-8
-# Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
-# Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
-# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
+# Copyright 2024 imoneoi and the LlamaFactory team.
+#
+# This code is inspired by imoneoi's OpenChat library.
+# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import math
 from typing import Literal
@@ -32,6 +45,10 @@ def calculate_lr(
     cutoff_len: int = 1024,  # i.e. maximum input length during training
     is_mistral: bool = False,  # mistral model uses a smaller learning rate,
 ):
+    r"""
+    Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
+    Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
+    """
     model_args, data_args, training_args, _, _ = get_train_args(
         dict(
             stage=stage,

View File

@@ -1,6 +1,17 @@
 # coding=utf-8
-# Calculates the ppl on the dataset of the pre-trained models.
-# Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import json
 from dataclasses import dataclass
@@ -56,6 +67,10 @@ def cal_ppl(
     max_samples: Optional[int] = None,
     train_on_prompt: bool = False,
 ):
+    r"""
+    Calculates the ppl on the dataset of the pre-trained models.
+    Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
+    """
     model_args, data_args, training_args, finetuning_args, _ = get_train_args(
         dict(
             stage=stage,

View File

@@ -1,6 +1,17 @@
 # coding=utf-8
-# Calculates the distribution of the input lengths in the dataset.
-# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from collections import defaultdict
@@ -19,6 +30,10 @@ def length_cdf(
     template: str = "default",
     interval: int = 1000,
 ):
+    r"""
+    Calculates the distribution of the input lengths in the dataset.
+    Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
+    """
     model_args, data_args, training_args, _, _ = get_train_args(
         dict(
             stage="sft",

View File

@@ -1,7 +1,20 @@
 # coding=utf-8
-# Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
-# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
-# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
+# Copyright 2024 Tencent Inc. and the LlamaFactory team.
+#
+# This code is inspired by Tencent's LLaMA-Pro library.
+# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import json
 import os
@@ -37,6 +50,10 @@ def block_expansion(
     shard_size: Optional[str] = "2GB",
     save_safetensors: Optional[bool] = False,
 ):
+    r"""
+    Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
+    Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
+    """
     config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
     num_layers = getattr(config, "num_hidden_layers")
     setattr(config, "num_hidden_layers", num_layers + num_expand)
@@ -103,7 +120,7 @@ def block_expansion(
         json.dump(index, f, indent=2, sort_keys=True)
     print("Model weights saved in {}".format(output_dir))
-    print("Fine-tune this model with:")
+    print("- Fine-tune this model with:")
     print("model_name_or_path: {}".format(output_dir))
     print("finetuning_type: freeze")
     print("freeze_trainable_layers: {}".format(num_expand))

View File

@@ -1,8 +1,17 @@
 # coding=utf-8
-# Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
-# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
-# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
-# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import json
 import os
@@ -79,6 +88,11 @@ def save_config(input_dir: str, output_dir: str):
 def llamafy_baichuan2(
     input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
 ):
+    r"""
+    Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
+    Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
+    Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
+    """
     try:
         os.makedirs(output_dir, exist_ok=False)
     except Exception as e:

View File

@@ -1,7 +1,17 @@
 # coding=utf-8
-# Converts the Qwen models in the same format as LLaMA2.
-# Usage: python llamafy_qwen.py --input_dir input --output_dir output
-# Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import json
 import os
@@ -131,6 +141,11 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
 def llamafy_qwen(
     input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
 ):
+    r"""
+    Converts the Qwen models in the same format as LLaMA2.
+    Usage: python llamafy_qwen.py --input_dir input --output_dir output
+    Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
+    """
     try:
         os.makedirs(output_dir, exist_ok=False)
     except Exception as e:

View File

@@ -1,14 +1,25 @@
 # coding=utf-8
-# Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
-# Usage: python loftq_init.py --model_name_or_path path_to_model --save_dir output_dir
-# Inspired by: https://github.com/huggingface/peft/blob/main/examples/loftq_finetuning/quantize_save_load.py
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is based on HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 import fire
-import torch
-import torch.nn as nn
 from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -17,38 +28,21 @@ if TYPE_CHECKING:
     from transformers import PreTrainedModel
-class Shell(nn.Module):
-    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
-        super().__init__()
-        self.weight = nn.Parameter(weight, requires_grad=False)
-        if bias is not None:
-            self.bias = nn.Parameter(bias, requires_grad=False)
-def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
-    for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}:
-        parent_name = ".".join(name.split(".")[:-1])
-        child_name = name.split(".")[-1]
-        parent_module = model.get_submodule(parent_name)
-        child_module = getattr(parent_module, child_name)
-        base_layer = getattr(child_module, "base_layer")
-        weight = getattr(base_layer, "weight", None)
-        bias = getattr(base_layer, "bias", None)
-        setattr(parent_module, child_name, Shell(weight, bias))
-    print("Model unwrapped.")
 def quantize_loftq(
     model_name_or_path: str,
-    save_dir: str,
-    loftq_bits: Optional[int] = 4,
-    loftq_iter: Optional[int] = 1,
-    lora_alpha: Optional[int] = None,
-    lora_rank: Optional[int] = 16,
-    lora_target: Optional[str] = "q_proj,v_proj",
-    save_safetensors: Optional[bool] = False,
+    output_dir: str,
+    loftq_bits: int = 4,
+    loftq_iter: int = 4,
+    lora_alpha: int = None,
+    lora_rank: int = 16,
+    lora_dropout: float = 0,
+    lora_target: str = "q_proj,v_proj",
+    save_safetensors: bool = True,
 ):
+    r"""
+    Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
+    Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
+    """
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
     model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
     loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
@@ -57,25 +51,34 @@ def quantize_loftq(
         inference_mode=True,
         r=lora_rank,
         lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
-        lora_dropout=0.1,
+        lora_dropout=lora_dropout,
         target_modules=[name.strip() for name in lora_target.split(",")],
         init_lora_weights="loftq",
         loftq_config=loftq_config,
     )
     # Init LoftQ model
-    lora_model = get_peft_model(model, lora_config)
-    base_model: "PreTrainedModel" = lora_model.get_base_model()
+    print("Initializing LoftQ weights, it may take several minutes, please wait patiently.")
+    peft_model = get_peft_model(model, lora_config)
+    loftq_dir = os.path.join(output_dir, "loftq_init")
     # Save LoftQ model
-    setattr(lora_model.base_model.peft_config["default"], "base_model_name_or_path", save_dir)
-    setattr(lora_model.base_model.peft_config["default"], "init_lora_weights", True)
-    lora_model.save_pretrained(os.path.join(save_dir, "adapters"), safe_serialization=save_safetensors)
+    setattr(peft_model.peft_config["default"], "base_model_name_or_path", output_dir)
+    setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply loftq again
+    peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
+    print("Adapter weights saved in {}".format(loftq_dir))
     # Save base model
-    unwrap_model(base_model)
-    base_model.save_pretrained(save_dir, safe_serialization=save_safetensors)
-    tokenizer.save_pretrained(save_dir)
+    base_model: "PreTrainedModel" = peft_model.unload()
+    base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
+    tokenizer.save_pretrained(output_dir)
+    print("Model weights saved in {}".format(output_dir))
+    print("- Fine-tune this model with:")
+    print("model_name_or_path: {}".format(output_dir))
+    print("adapter_name_or_path: {}".format(loftq_dir))
+    print("finetuning_type: lora")
+    print("quantization_bit: {}".format(loftq_bits))
 if __name__ == "__main__":

scripts/pissa_init.py (new file, 82 lines)
View File

@@ -0,0 +1,82 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is based on HuggingFace's PEFT library.
+# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import TYPE_CHECKING
+import fire
+from peft import LoraConfig, TaskType, get_peft_model
+from transformers import AutoModelForCausalLM, AutoTokenizer
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+def quantize_pissa(
+    model_name_or_path: str,
+    output_dir: str,
+    pissa_iter: int = 4,
+    lora_alpha: int = None,
+    lora_rank: int = 16,
+    lora_dropout: float = 0,
+    lora_target: str = "q_proj,v_proj",
+    save_safetensors: bool = True,
+):
+    r"""
+    Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
+    Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
+    """
+    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")
+    lora_config = LoraConfig(
+        task_type=TaskType.CAUSAL_LM,
+        r=lora_rank,
+        lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
+        lora_dropout=lora_dropout,
+        target_modules=[name.strip() for name in lora_target.split(",")],
+        init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
+    )
+    # Init PiSSA model
+    peft_model = get_peft_model(model, lora_config)
+    pissa_dir = os.path.join(output_dir, "pissa_init")
+    # Save PiSSA model
+    setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply pissa again
+    peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
+    print("Adapter weights saved in {}".format(pissa_dir))
+    # Save base model
+    base_model: "PreTrainedModel" = peft_model.unload()
+    base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
+    tokenizer.save_pretrained(output_dir)
+    print("Model weights saved in {}".format(output_dir))
+    print("- Fine-tune this model with:")
+    print("model_name_or_path: {}".format(output_dir))
+    print("adapter_name_or_path: {}".format(pissa_dir))
+    print("finetuning_type: lora")
+    print("pissa_init: false")
+    print("pissa_convert: true")
+    print("- and optionally with:")
+    print("quantization_bit: 4")
+if __name__ == "__main__":
+    fire.Fire(quantize_pissa)
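Since the entry point is wrapped with `fire.Fire`, `quantize_pissa` can also be called directly from Python. A minimal usage sketch; the paths below are placeholders, not values from this commit, and the import assumes you run from the `scripts/` directory:

```python
# Hypothetical invocation of scripts/pissa_init.py; paths are placeholders.
from pissa_init import quantize_pissa

quantize_pissa(
    model_name_or_path="path_to_model",
    output_dir="output_dir",
    lora_rank=16,  # default, shown explicitly
    pissa_iter=4,  # -1 selects init_lora_weights="pissa" instead of "pissa_niter_4"
)
```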

View File

@@ -1,3 +1,18 @@
+# coding=utf-8
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import json
 import os
 from typing import Sequence

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import os
 import re
@@ -23,7 +37,7 @@ extra_require = {
     "torch": ["torch>=1.13.1"],
     "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
-    "deepspeed": ["deepspeed>=0.10.0,<=0.14.0"],
+    "deepspeed": ["deepspeed>=0.10.0"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "vllm": ["vllm>=0.4.3"],
     "galore": ["galore-torch"],

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import os
 import uvicorn

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 # Level: api, webui > chat, eval, train > data, model > hparams > extras
 from .cli import VERSION

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import os
 from contextlib import asynccontextmanager
 from typing import Optional

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import base64
 import io
 import json
@@ -78,9 +92,11 @@ def _process_request(
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
         if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
-            name = message.tool_calls[0].function.name
-            arguments = message.tool_calls[0].function.arguments
-            content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
+            tool_calls = [
+                {"name": tool_call.function.name, "argument": tool_call.function.arguments}
+                for tool_call in message.tool_calls
+            ]
+            content = json.dumps(tool_calls, ensure_ascii=False)
             input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
         elif isinstance(message.content, list):
             for input_item in message.content:
@@ -104,7 +120,7 @@ def _process_request(
     if isinstance(tool_list, list) and len(tool_list):
         try:
             tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
-        except Exception:
+        except json.JSONDecodeError:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
     else:
         tools = None
@@ -146,15 +162,17 @@ async def create_chat_completion_response(
     choices = []
     for i, response in enumerate(responses):
         if tools:
-            result = chat_model.engine.template.format_tools.extract(response.response_text)
+            result = chat_model.engine.template.extract_tool(response.response_text)
         else:
             result = response.response_text
-        if isinstance(result, tuple):
-            name, arguments = result
-            function = Function(name=name, arguments=arguments)
-            tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
-            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
+        if isinstance(result, list):
+            tool_calls = []
+            for tool in result:
+                function = Function(name=tool[0], arguments=tool[1])
+                tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
             finish_reason = Finish.TOOL
         else:
             response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
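The reworked `_process_request` now serializes every assistant tool call into a single FUNCTION-role message instead of only the first one. A standalone sketch of the resulting payload, with invented tool names:

```python
import json

# Invented example tool calls, mirroring the {"name": ..., "argument": ...}
# shape built by the list comprehension in the diff above.
tool_calls = [
    {"name": "get_weather", "argument": '{"city": "Beijing"}'},
    {"name": "get_time", "argument": '{"timezone": "UTC"}'},
]
content = json.dumps(tool_calls, ensure_ascii=False)
print(content)  # a JSON list with one object per tool call
```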

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import json
 from typing import TYPE_CHECKING, Any, Dict

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import time
 from enum import Enum, unique
 from typing import Any, Dict, List, Optional, Union

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from .base_engine import BaseEngine
 from .chat_model import ChatModel

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
@@ -36,11 +50,6 @@ class BaseEngine(ABC):
         generating_args: "GeneratingArguments",
     ) -> None: ...
-    @abstractmethod
-    async def start(
-        self,
-    ) -> None: ...
     @abstractmethod
     async def chat(
         self,

View File

@@ -1,3 +1,20 @@
+# Copyright 2024 THUDM and the LlamaFactory team.
+#
+# This code is inspired by THUDM's ChatGLM implementation.
+# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import asyncio
 from threading import Thread
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
@@ -14,7 +31,7 @@ if TYPE_CHECKING:
     from .base_engine import BaseEngine, Response
-def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
+def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
     asyncio.set_event_loop(loop)
     loop.run_forever()
@@ -32,7 +49,6 @@ class ChatModel:
         self._loop = asyncio.new_event_loop()
         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
         self._thread.start()
-        asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
     def chat(
         self,
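With the `engine.start()` call removed, nothing needs to be scheduled at construction time anymore, but the dedicated background loop stays. A self-contained sketch of that loop-in-a-thread pattern, using only the standard library:

```python
import asyncio
from threading import Thread


def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()


async def echo(text: str) -> str:
    return text


# Synchronous code submits coroutines to the private loop and blocks on the future.
loop = asyncio.new_event_loop()
Thread(target=_start_background_loop, args=(loop,), daemon=True).start()
future = asyncio.run_coroutine_threadsafe(echo("hello"), loop)
print(future.result())  # hello
```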

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import asyncio
 import concurrent.futures
 import os
@@ -45,6 +59,14 @@ class HuggingfaceEngine(BaseEngine):
             self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )  # must after fixing tokenizer to resize vocab
         self.generating_args = generating_args.to_dict()
+        try:
+            asyncio.get_event_loop()
+        except RuntimeError:
+            logger.warning("There is no current event loop, creating a new one.")
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
     @staticmethod
     def _process_args(
@@ -245,9 +267,6 @@ class HuggingfaceEngine(BaseEngine):
         return scores
-    async def start(self) -> None:
-        self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -272,7 +291,7 @@ class HuggingfaceEngine(BaseEngine):
             image,
             input_kwargs,
         )
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._chat, *input_args)
@@ -300,7 +319,7 @@ class HuggingfaceEngine(BaseEngine):
             image,
             input_kwargs,
         )
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 stream = self._stream_chat(*input_args)
                 while True:
@@ -319,6 +338,6 @@ class HuggingfaceEngine(BaseEngine):
         loop = asyncio.get_running_loop()
         input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._get_scores, *input_args)
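The semaphore formerly created in `start()` now lives in `__init__`, guarded by an event-loop check. Its effect is easiest to see in isolation; a minimal sketch where a sleep stands in for the blocking generation call:

```python
import asyncio
import os


async def generate(i: int, semaphore: asyncio.Semaphore) -> None:
    async with semaphore:  # at most MAX_CONCURRENT tasks enter at once
        print(f"request {i} running")
        await asyncio.sleep(0.1)  # stand-in for the real generation work


async def main() -> None:
    semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
    await asyncio.gather(*(generate(i, semaphore) for i in range(4)))


asyncio.run(main())
```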

View File

@@ -1,10 +1,24 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 from ..data import get_template_and_fix_tokenizer
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available
+from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response
@@ -13,7 +27,11 @@ from .base_engine import BaseEngine, Response
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest
-    from vllm.sequence import MultiModalData
+    if is_vllm_version_greater_than_0_5():
+        from vllm.multimodal.image import ImagePixelData
+    else:
+        from vllm.sequence import MultiModalData
 if TYPE_CHECKING:
@@ -48,7 +66,7 @@ class VllmEngine(BaseEngine):
             "model": model_args.model_name_or_path,
             "trust_remote_code": True,
             "download_dir": model_args.cache_dir,
-            "dtype": model_args.vllm_dtype,
+            "dtype": model_args.infer_dtype,
             "max_model_len": model_args.vllm_maxlen,
             "tensor_parallel_size": get_device_count() or 1,
             "gpu_memory_utilization": model_args.vllm_gpu_util,
@@ -106,7 +124,10 @@ class VllmEngine(BaseEngine):
         if self.processor is not None and image is not None:  # add image features
             image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
             pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
-            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+            if is_vllm_version_greater_than_0_5():
+                multi_modal_data = ImagePixelData(image=pixel_values)
+            else:  # TODO: remove vllm 0.4.3 support
+                multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
         else:
             multi_modal_data = None
@@ -162,9 +183,6 @@ class VllmEngine(BaseEngine):
         )
         return result_generator
-    async def start(self) -> None:
-        pass
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
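`is_vllm_version_greater_than_0_5` is imported from `..extras.packages`, whose implementation is not part of this diff. A plausible sketch of such a version gate, offered only as an assumption about how it might look:

```python
# Assumed implementation; the real helper lives in llamafactory.extras.packages.
from importlib import metadata

from packaging import version


def is_vllm_version_greater_than_0_5() -> bool:
    try:
        return version.parse(metadata.version("vllm")) >= version.parse("0.5.0")
    except metadata.PackageNotFoundError:
        return False
```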

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import os
 import random
 import subprocess

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
 from .data_utils import Role, split_dataset
 from .loader import get_dataset

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import os
 from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, List, Union
@@ -10,6 +24,7 @@ from .data_utils import Role
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
+    from transformers import Seq2SeqTrainingArguments
     from ..hparams import DataArguments
     from .parser import DatasetAttr
@@ -175,7 +190,10 @@ def convert_sharegpt(
 def align_dataset(
-    dataset: Union["Dataset", "IterableDataset"], dataset_attr: "DatasetAttr", data_args: "DataArguments"
+    dataset: Union["Dataset", "IterableDataset"],
+    dataset_attr: "DatasetAttr",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
     r"""
     Aligned dataset:
@@ -208,7 +226,7 @@ def align_dataset(
     if not data_args.streaming:
         kwargs = dict(
             num_proc=data_args.preprocessing_num_workers,
-            load_from_cache_file=(not data_args.overwrite_cache),
+            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Converting format of dataset",
         )
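The new `load_from_cache_file` expression encodes a distributed-caching rule: only local rank 0 may rebuild the dataset cache, while every other rank always reuses what rank 0 produced. Restated as a standalone check:

```python
# Standalone restatement of the expression in the diff above.
def load_from_cache_file(overwrite_cache: bool, local_process_index: int) -> bool:
    return (not overwrite_cache) or (local_process_index != 0)


assert load_from_cache_file(False, 0) is True   # no overwrite requested: reuse cache
assert load_from_cache_file(True, 0) is False   # rank 0 rebuilds the cache
assert load_from_cache_file(True, 1) is True    # other ranks reuse rank 0's cache
```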

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from dataclasses import dataclass
 from typing import Any, Dict, Sequence

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from enum import Enum, unique
 from typing import TYPE_CHECKING, Dict, List, Tuple, Union

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import json
 import re
 from abc import ABC, abstractmethod
@@ -8,21 +22,23 @@ from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Uni
 SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
-JSON_FORMAT_PROMPT = (
-    """, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
-)
-TOOL_SYSTEM_PROMPT = (
+DEFAULT_TOOL_PROMPT = (
     "You have access to the following tools:\n{tool_text}"
     "Use the following format if using a tool:\n"
     "```\n"
     "Action: tool name (one of [{tool_names}]).\n"
-    "Action Input: the input to the tool{format_prompt}.\n"
+    "Action Input: the input to the tool, in a JSON format representing the kwargs "
+    """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
     "```\n"
 )
+GLM4_TOOL_PROMPT = (
+    "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
+    "你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}"
+)
 def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
     tool_text = ""
     tool_names = []
@@ -48,36 +64,60 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
         )
         tool_names.append(tool["name"])
-    return TOOL_SYSTEM_PROMPT.format(
-        tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
-    )
+    return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
-def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
-    regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
-    action_match = re.search(regex, content)
+def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+    regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
+    action_match: List[Tuple[str, str]] = re.findall(regex, content)
     if not action_match:
         return content
-    tool_name = action_match.group(1).strip()
-    tool_input = action_match.group(2).strip().strip('"').strip("```")
-    try:
-        arguments = json.loads(tool_input)
-    except json.JSONDecodeError:
-        return content
-    return tool_name, json.dumps(arguments, ensure_ascii=False)
+    results = []
+    for match in action_match:
+        tool_name = match[0].strip()
+        tool_input = match[1].strip().strip('"').strip("```")
+        try:
+            arguments = json.loads(tool_input)
+            results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
+        except json.JSONDecodeError:
+            return content
+    return results
+def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
+    tool_text = ""
+    for tool in tools:
+        tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
+            name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
+        )
+    return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
+def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+    if "\n" not in content:
+        return content
+    tool_name, tool_input = content.split("\n", maxsplit=1)
+    try:
+        arguments = json.loads(tool_input)
+    except json.JSONDecodeError:
+        return content
+    return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
-    tool_format: Optional[Literal["default"]] = None
+    tool_format: Optional[Literal["default", "glm4"]] = None
     @abstractmethod
     def apply(self, **kwargs) -> SLOTS: ...
-    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
+    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
         raise NotImplementedError
@@ -140,22 +180,28 @@ class FunctionFormatter(Formatter):
    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
+        functions: List[Tuple[str, str]] = []
        try:
-            function = json.loads(content)
-            name = function["name"]
-            arguments = json.dumps(function["arguments"], ensure_ascii=False)
-        except Exception:
-            name, arguments = "", ""
+            tool_calls = json.loads(content)
+            if not isinstance(tool_calls, list):  # parallel function call
+                tool_calls = [tool_calls]
+
+            for tool_call in tool_calls:
+                functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+
+        except json.JSONDecodeError:
+            functions = []

        elements = []
-        for slot in self.slots:
-            if isinstance(slot, str):
-                slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
-                elements.append(slot)
-            elif isinstance(slot, (dict, set)):
-                elements.append(slot)
-            else:
-                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+        for name, arguments in functions:
+            for slot in self.slots:
+                if isinstance(slot, str):
+                    slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
+                    elements.append(slot)
+                elif isinstance(slot, (dict, set)):
+                    elements.append(slot)
+                else:
+                    raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

        return elements
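
A standalone sketch of the new apply() contract: the function message may hold either one JSON object or a JSON list of calls, and the slots are rendered once per call. Slot syntax is borrowed from the default template; the payload is invented.

import json

content = json.dumps([
    {"name": "get_weather", "arguments": {"city": "Beijing"}},
    {"name": "get_time", "arguments": {"timezone": "Asia/Shanghai"}},
])
slots = ["Action: {{name}}\nAction Input: {{arguments}}\n"]

tool_calls = json.loads(content)
if not isinstance(tool_calls, list):  # a single object is wrapped, mirroring the patch
    tool_calls = [tool_calls]

rendered = []
for call in tool_calls:
    for slot in slots:
        rendered.append(
            slot.replace("{{name}}", call["name"]).replace(
                "{{arguments}}", json.dumps(call["arguments"], ensure_ascii=False)
            )
        )
print("".join(rendered))
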
@@ -163,25 +209,22 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
    def __post_init__(self):
-        if self.tool_format is None:
+        if self.tool_format == "default":
+            self._tool_formatter = default_tool_formatter
+            self._tool_extractor = default_tool_extractor
+        elif self.tool_format == "glm4":
+            self._tool_formatter = glm4_tool_formatter
+            self._tool_extractor = glm4_tool_extractor
+        else:
            raise ValueError("Tool format was not found.")

    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        try:
            tools = json.loads(content)
-            if not len(tools):
-                return [""]
-
-            if self.tool_format == "default":
-                return [default_tool_formatter(tools)]
-            else:
-                raise NotImplementedError
-        except Exception:
+            return [self._tool_formatter(tools) if len(tools) != 0 else ""]
+        except json.JSONDecodeError:
            return [""]

-    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
-        if self.tool_format == "default":
-            return default_tool_extractor(content)
-        else:
-            raise NotImplementedError
+    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+        return self._tool_extractor(content)

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import sys
@@ -18,8 +32,7 @@ from .template import get_template_and_fix_tokenizer
if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

    from ..hparams import DataArguments, ModelArguments
    from .parser import DatasetAttr
@@ -32,6 +45,7 @@ def load_single_dataset(
    dataset_attr: "DatasetAttr",
    model_args: "ModelArguments",
    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    logger.info("Loading dataset {}...".format(dataset_attr))
    data_path, data_name, data_dir, data_files = None, None, None, None
@@ -123,7 +137,7 @@ def load_single_dataset(
        max_samples = min(data_args.max_samples, len(dataset))
        dataset = dataset.select(range(max_samples))

-    return align_dataset(dataset, dataset_attr, data_args)
+    return align_dataset(dataset, dataset_attr, data_args, training_args)
def get_dataset(
@@ -157,7 +171,8 @@ def get_dataset(
if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True): if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.") raise ValueError("The dataset is not applicable in the current training stage.")
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args)) all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args, training_args))
dataset = merge_dataset(all_datasets, data_args, training_args) dataset = merge_dataset(all_datasets, data_args, training_args)
with training_args.main_process_first(desc="pre-process dataset"): with training_args.main_process_first(desc="pre-process dataset"):
@@ -169,7 +184,7 @@ def get_dataset(
        if not data_args.streaming:
            kwargs = dict(
                num_proc=data_args.preprocessing_num_workers,
-                load_from_cache_file=(not data_args.overwrite_cache),
+                load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
                desc="Running tokenizer on dataset",
            )
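
The new load_from_cache_file expression encodes a multi-process rule. A sketch of its truth table, assuming main_process_first lets rank 0 run the mapping first, as in the call above:

def should_load_from_cache(overwrite_cache: bool, local_process_index: int) -> bool:
    # Rank 0 may rebuild the cache when overwrite_cache=True; every other rank
    # always loads the cache rank 0 just wrote, so ranks never race on the files.
    return (not overwrite_cache) or (local_process_index != 0)

assert should_load_from_cache(overwrite_cache=True, local_process_index=0) is False
assert should_load_from_cache(overwrite_cache=True, local_process_index=1) is True
assert should_load_from_cache(overwrite_cache=False, local_process_index=0) is True
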

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from dataclasses import dataclass

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
@@ -13,8 +27,7 @@ from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsu
if TYPE_CHECKING:
-    from transformers import ProcessorMixin, Seq2SeqTrainingArguments
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .template import Template

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras.constants import IGNORE_INDEX
@@ -6,8 +20,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..template import Template

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras.constants import IGNORE_INDEX
@@ -6,8 +20,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..template import Template

View File

@@ -1,9 +1,26 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer

    from ...hparams import DataArguments
@@ -12,7 +29,8 @@ def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
-    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
+    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
+    text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]

    if not data_args.packing:
        if data_args.template == "gemma":
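
The eos swap matters because the llama3 chat tokenizer reports "<|eot_id|>" (a turn delimiter) as its eos_token, while document packing for pretraining wants the text terminator. A sketch of the selection; the helper name is invented:

def pretrain_eos(template: str, tokenizer_eos: str) -> str:
    # llama3's tokenizer eos is the chat turn delimiter, not the document terminator.
    return "<|end_of_text|>" if template == "llama3" else tokenizer_eos

assert pretrain_eos("llama3", "<|eot_id|>") == "<|end_of_text|>"
assert pretrain_eos("gemma", "<eos>") == "<eos>"
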

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import bisect
from typing import TYPE_CHECKING, List, Sequence

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
@@ -7,8 +21,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, gre
if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..template import Template

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras.logging import get_logger
@@ -6,8 +20,7 @@ from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
if TYPE_CHECKING:
-    from transformers import ProcessorMixin
-    from transformers.tokenization_utils import PreTrainedTokenizer
+    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..template import Template

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
@@ -24,12 +38,12 @@ class Template:
    format_observation: "Formatter"
    format_tools: "Formatter"
    format_separator: "Formatter"
+    format_prefix: "Formatter"
    default_system: str
    stop_words: List[str]
    image_token: str
    efficient_eos: bool
    replace_eos: bool
-    force_system: bool
    def encode_oneturn(
        self,
@@ -65,6 +79,12 @@ class Template:
""" """
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
r"""
Extracts tool message.
"""
return self.format_tools.extract(content)
    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
@@ -83,10 +103,15 @@ class Template:
        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []

+            if i == 0:
+                elements += self.format_prefix.apply()

-            if i == 0 and (system or tools or self.force_system):
+            if i == 0 and (system or tools):
                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                elements += self.format_system.apply(content=(system + tool_text))

-            elif i > 0 and i % 2 == 0:
+            if i > 0 and i % 2 == 0:
                elements += self.format_separator.apply()

            if message["role"] == Role.USER.value:
@@ -173,11 +198,16 @@ class Llama2Template(Template):
        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []

+            if i == 0:
+                elements += self.format_prefix.apply()

            system_text = ""
-            if i == 0 and (system or tools or self.force_system):
+            if i == 0 and (system or tools):
                tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                system_text = self.format_system.apply(content=(system + tool_text))[0]

-            elif i > 0 and i % 2 == 0:
+            if i > 0 and i % 2 == 0:
                elements += self.format_separator.apply()

            if message["role"] == Role.USER.value:
@@ -208,12 +238,12 @@ def _register_template(
    format_observation: Optional["Formatter"] = None,
    format_tools: Optional["Formatter"] = None,
    format_separator: Optional["Formatter"] = None,
+    format_prefix: Optional["Formatter"] = None,
    default_system: str = "",
    stop_words: List[str] = [],
    image_token: str = "<image>",
    efficient_eos: bool = False,
    replace_eos: bool = False,
-    force_system: bool = False,
) -> None:
    r"""
    Registers a chat template.
@@ -245,9 +275,12 @@ def _register_template(
    template_class = Llama2Template if name.startswith("llama2") else Template
    default_user_formatter = StringFormatter(slots=["{{content}}"])
    default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
-    default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
+    default_function_formatter = FunctionFormatter(
+        slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots
+    )
    default_tool_formatter = ToolFormatter(tool_format="default")
    default_separator_formatter = EmptyFormatter()
+    default_prefix_formatter = EmptyFormatter()
    TEMPLATES[name] = template_class(
        format_user=format_user or default_user_formatter,
        format_assistant=format_assistant or default_assistant_formatter,
@@ -256,12 +289,12 @@ def _register_template(
        format_observation=format_observation or format_user or default_user_formatter,
        format_tools=format_tools or default_tool_formatter,
        format_separator=format_separator or default_separator_formatter,
+        format_prefix=format_prefix or default_prefix_formatter,
        default_system=default_system,
        stop_words=stop_words,
        image_token=image_token,
        efficient_eos=efficient_eos,
        replace_eos=replace_eos,
-        force_system=force_system,
    )
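
The prefix hook is what retires force_system across the templates below: a fixed leading slot, usually a BOS token, is emitted unconditionally instead of being smuggled through the system formatter. A toy illustration with a simplified stand-in class, not the repo's EmptyFormatter:

class EmptyFormatterSketch:
    # Minimal stand-in: apply() just returns its fixed slots.
    def __init__(self, slots=None):
        self.slots = slots or []

    def apply(self):
        return list(self.slots)

prefix = EmptyFormatterSketch(slots=["<bos>"])
print(prefix.apply())                   # ['<bos>'] -> prepended at turn 0, system text or not
print(EmptyFormatterSketch().apply())   # [] -> templates without a prefix add nothing
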
@@ -307,6 +340,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
    jinja_template = ""

+    prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
+    if prefix:
+        jinja_template += "{{ " + prefix + " }}"

    if template.default_system:
        jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
@@ -315,11 +352,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
    )
    system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
-    if isinstance(template, Llama2Template):
-        pass
-    elif template.force_system:
-        jinja_template += "{{ " + system_message + " }}"
-    else:
+    if not isinstance(template, Llama2Template):
        jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"

    jinja_template += "{% for message in messages %}"
@@ -435,9 +468,8 @@ _register_template(
_register_template(
    name="belle",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@@ -450,11 +482,7 @@ _register_template(
_register_template(
    name="breeze",
    format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
-    default_system=(
-        "You are a helpful AI assistant built by MediaTek Research. "
-        "The user you are helping speaks Traditional Chinese and comes from Taiwan."
-    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    efficient_eos=True,
)
@@ -462,10 +490,9 @@ _register_template(
_register_template(
    name="chatglm2",
    format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
-    format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
+    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
    efficient_eos=True,
-    force_system=True,
)
@@ -473,14 +500,14 @@ _register_template(
name="chatglm3", name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
), ),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
stop_words=["<|user|>", "<|observation|>"], stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True, efficient_eos=True,
force_system=True,
) )
@@ -488,13 +515,12 @@ _register_template(
name="chatglm3_system", name="chatglm3_system",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter( format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
), ),
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
default_system=( default_system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. " "You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown." "Follow the user's instructions carefully. Respond using markdown."
@@ -529,8 +555,7 @@ _register_template(
_register_template(
    name="codegeex2",
-    format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
)
@@ -544,21 +569,15 @@ _register_template(
            )
        ]
    ),
-    format_system=StringFormatter(
-        slots=[{"bos_token"}, "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]
-    ),
-    default_system=(
-        "You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users "
-        "by providing thorough responses. You are trained by Cohere."
-    ),
+    format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


_register_template(
    name="cpm",
    format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@@ -591,8 +610,7 @@ _register_template(
_register_template(
    name="deepseek",
    format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@@ -622,11 +640,8 @@ _register_template(
_register_template(
    name="empty",
-    format_user=StringFormatter(slots=["{{content}}"]),
-    format_assistant=StringFormatter(slots=["{{content}}"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    efficient_eos=True,
-    force_system=True,
)
@@ -648,13 +663,12 @@ _register_template(
_register_template(
    name="gemma",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    efficient_eos=True,
-    force_system=True,
)
@@ -662,36 +676,33 @@ _register_template(
name="glm4", name="glm4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]), format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["[gMASK]<sop>{{content}}"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]), format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
stop_words=["<|user|>", "<|observation|>"], stop_words=["<|user|>", "<|observation|>"],
efficient_eos=True, efficient_eos=True,
force_system=True,
) )
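
Putting the glm4 pieces together, a first round would plausibly serialize as below. This string is assembled by hand from the slots above, not produced by the repo:

prompt = (
    "[gMASK]<sop>"                              # format_prefix, emitted once at turn 0
    "<|system|>\nYou are a helpful assistant."  # format_system
    "<|user|>\nHello<|assistant|>"              # format_user; the reply follows after "\n"
)
print(prompt)
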
_register_template(
    name="intern",
-    format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
-    format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
+    format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
+    format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
+    format_separator=EmptyFormatter(slots=["<eoa>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<eoa>"],
-    efficient_eos=True,
+    efficient_eos=True,  # internlm tokenizer cannot set eos_token_id
)
_register_template(
    name="intern2",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
-    default_system=(
-        "You are an AI assistant whose name is InternLM (书生·浦语).\n"
-        "- InternLM (书生·浦语) is a conversational language model that is developed "
-        "by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
-        "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen "
-        "by the user such as English and 中文."
-    ),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|im_end|>"],
    efficient_eos=True,  # internlm2 tokenizer cannot set eos_token_id
)
@@ -700,7 +711,6 @@ _register_template(
_register_template(
    name="llama2",
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
-    format_assistant=StringFormatter(slots=[" {{content}} ", {"eos_token"}]),
    format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
)
@@ -723,9 +733,7 @@ _register_template(
            )
        ]
    ),
-    format_system=StringFormatter(
-        slots=[{"bos_token"}, "<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]
-    ),
+    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
    format_observation=StringFormatter(
        slots=[
            (
@@ -734,7 +742,7 @@ _register_template(
            )
        ]
    ),
-    default_system="You are a helpful assistant.",
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot_id|>"],
    replace_eos=True,
)
@@ -743,24 +751,21 @@ _register_template(
_register_template(
    name="mistral",
    format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


_register_template(
    name="olmo",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
-    format_system=StringFormatter(slots=[{"eos_token"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
)


_register_template(
    name="openchat",
    format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
@@ -774,27 +779,25 @@ _register_template(
            )
        ]
    ),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot_id|>"],
    replace_eos=True,
-    force_system=True,
)


_register_template(
    name="orion",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
-    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
-    force_system=True,
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
_register_template( _register_template(
name="phi", name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|system|>\n{{content}}<|end|>\n"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful AI assistant.", format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|end|>"], stop_words=["<|end|>"],
replace_eos=True, replace_eos=True,
) )
@@ -827,7 +830,6 @@ _register_template(
    format_separator=EmptyFormatter(slots=["\n"]),
    stop_words=["<|end|>"],
    replace_eos=True,
-    force_system=True,
)

View File

@@ -1,4 +1,41 @@
-# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
+# Copyright 2024 the LlamaFactory team.
#
# This code is inspired by the Dan's test library.
# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2020 Dan Hendrycks
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import inspect
import json

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict, defaultdict
from enum import Enum
from typing import Dict, Optional
@@ -389,6 +403,18 @@ register_model_group(
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
        },
"DeepSeek-MoE-Coder-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
},
"DeepSeek-MoE-Coder-236B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
},
"DeepSeek-MoE-Coder-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
},
"DeepSeek-MoE-Coder-236B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
},
    },
    template="deepseek",
)
@@ -668,6 +694,21 @@ register_model_group(
)
register_model_group(
models={
"MiniCPM-2B-SFT-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
},
"MiniCPM-2B-DPO-Chat": {
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
},
},
template="cpm",
)
register_model_group(
    models={
        "Mistral-7B-v0.1": {

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform

import accelerate
@@ -9,7 +23,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available

-VERSION = "0.8.1"
+VERSION = "0.8.2"


def print_env() -> None:

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
from typing import TYPE_CHECKING, Dict, Tuple
@@ -8,6 +22,7 @@ from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTr
from transformers.utils import (
    SAFE_WEIGHTS_NAME,
    WEIGHTS_NAME,
+    is_safetensors_available,
    is_torch_bf16_gpu_available,
    is_torch_cuda_available,
    is_torch_mps_available,
@@ -20,6 +35,11 @@ from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger


+if is_safetensors_available():
+    from safetensors import safe_open
+    from safetensors.torch import save_file


_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
    _is_bf16_available = is_torch_bf16_gpu_available()
@@ -114,9 +134,6 @@ def fix_valuehead_checkpoint(
        return

    if safe_serialization:
-        from safetensors import safe_open
-        from safetensors.torch import save_file
-
        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}

View File

@@ -1,5 +1,23 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata
import importlib.util
+from functools import lru_cache
from typing import TYPE_CHECKING

from packaging import version
@@ -24,10 +42,6 @@ def is_fastapi_available():
return _is_package_available("fastapi") return _is_package_available("fastapi")
def is_flash_attn2_available():
return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
def is_galore_available(): def is_galore_available():
return _is_package_available("galore_torch") return _is_package_available("galore_torch")
@@ -36,18 +50,10 @@ def is_gradio_available():
return _is_package_available("gradio") return _is_package_available("gradio")
def is_jieba_available():
return _is_package_available("jieba")
def is_matplotlib_available(): def is_matplotlib_available():
return _is_package_available("matplotlib") return _is_package_available("matplotlib")
def is_nltk_available():
return _is_package_available("nltk")
def is_pillow_available(): def is_pillow_available():
return _is_package_available("PIL") return _is_package_available("PIL")
@@ -60,10 +66,6 @@ def is_rouge_available():
return _is_package_available("rouge_chinese") return _is_package_available("rouge_chinese")
def is_sdpa_available():
return _get_package_version("torch") > version.parse("2.1.1")
def is_starlette_available(): def is_starlette_available():
return _is_package_available("sse_starlette") return _is_package_available("sse_starlette")
@@ -74,3 +76,8 @@ def is_uvicorn_available():
def is_vllm_available():
    return _is_package_available("vllm")


+@lru_cache
+def is_vllm_version_greater_than_0_5():
+    return _get_package_version("vllm") >= version.parse("0.5.0")

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import os

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments

View File

@@ -1,3 +1,20 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Literal, Optional

View File

@@ -1,3 +1,17 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import Literal, Optional

View File

@@ -1,5 +1,19 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
-from typing import Literal, Optional
+from typing import List, Literal, Optional


@dataclass
@@ -94,6 +108,18 @@ class LoraArguments:
        default=False,
        metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
    )
pissa_init: bool = field(
default=False,
metadata={"help": "Whether or not to initialize a PiSSA adapter."},
)
pissa_iter: int = field(
default=4,
metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
)
pissa_convert: bool = field(
default=False,
metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
)
    create_new_adapter: bool = field(
        default=False,
        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
    )
@@ -319,20 +345,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
            return [item.strip() for item in arg.split(",")]
        return arg

-        self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
-        self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
-        self.lora_alpha = self.lora_alpha or self.lora_rank * 2
-        self.lora_target = split_arg(self.lora_target)
-        self.additional_target = split_arg(self.additional_target)
-        self.galore_target = split_arg(self.galore_target)
+        self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
+        self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
+        self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
+        self.lora_target: List[str] = split_arg(self.lora_target)
+        self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
+        self.galore_target: List[str] = split_arg(self.galore_target)
        self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
+        self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
-        self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]

        if self.stage == "ppo" and self.reward_model is None:
            raise ValueError("`reward_model` is necessary for PPO training.")
@@ -354,5 +379,11 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
        if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
            raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")

+        if self.pissa_convert and self.finetuning_type != "lora":
+            raise ValueError("`pissa_convert` is only valid for LoRA training.")
+
+        if self.pissa_convert and (self.stage in ["rm", "ppo", "kto"] or self.use_ref_model):
+            raise ValueError("Cannot use PiSSA for current training stage.")

        if self.train_mm_proj_only and self.finetuning_type != "full":
            raise ValueError("`train_mm_proj_only` is only valid for full training.")

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import asdict, dataclass, field
 from typing import Any, Dict, Optional

View File

@@ -1,5 +1,28 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, Literal, Optional
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
+
+from typing_extensions import Self
+
+
+if TYPE_CHECKING:
+    import torch
+

 @dataclass
@@ -22,6 +45,10 @@ class ModelArguments:
             )
         },
     )
+    adapter_folder: Optional[str] = field(
+        default=None,
+        metadata={"help": "The folder containing the adapter weights to load."},
+    )
     cache_dir: Optional[str] = field(
         default=None,
         metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
@@ -127,13 +154,9 @@ class ModelArguments:
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."}, metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
) )
vllm_max_lora_rank: int = field( vllm_max_lora_rank: int = field(
default=8, default=32,
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."}, metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
) )
vllm_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
default="auto",
metadata={"help": "Data type for model weights and activations in the vLLM engine."},
)
offload_folder: str = field( offload_folder: str = field(
default="offload", default="offload",
metadata={"help": "Path to offload model weights."}, metadata={"help": "Path to offload model weights."},
@@ -142,6 +165,10 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether or not to use KV cache in generation."},
     )
+    infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
+        default="auto",
+        metadata={"help": "Data type for model weights and activations at inference."},
+    )
     hf_hub_token: Optional[str] = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."},
@@ -192,9 +219,9 @@ class ModelArguments:
     )

     def __post_init__(self):
-        self.compute_dtype = None
-        self.device_map = None
-        self.model_max_length = None
+        self.compute_dtype: Optional["torch.dtype"] = None
+        self.device_map: Optional[Union[str, Dict[str, Any]]] = None
+        self.model_max_length: Optional[int] = None

         if self.split_special_tokens and self.use_fast_tokenizer:
             raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
@@ -216,3 +243,13 @@ class ModelArguments:
     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
+
+    @classmethod
+    def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
+        arg_dict = old_arg.to_dict()
+        arg_dict.update(**kwargs)
+        new_arg = cls(**arg_dict)
+        new_arg.compute_dtype = old_arg.compute_dtype
+        new_arg.device_map = old_arg.device_map
+        new_arg.model_max_length = old_arg.model_max_length
+        return new_arg
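
The new `copyfrom` classmethod clones an argument object while overriding selected fields, then restores the attributes that `__post_init__` resets to None and that a plain `cls(**arg_dict)` round trip would otherwise lose. A hedged usage sketch (the chosen override is illustrative):

    # Clone the parsed arguments but disable the KV cache for a one-off run.
    new_args = ModelArguments.copyfrom(model_args, use_cache=False)
    assert new_args.compute_dtype == model_args.compute_dtype  # carried over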

View File

@@ -1,3 +1,20 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import logging
 import os
 import sys
@@ -8,6 +25,7 @@ import transformers
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import get_last_checkpoint
+from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available
 from transformers.utils.versions import require_version
@@ -72,6 +90,9 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
if finetuning_args.finetuning_type != "lora": if finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.") raise ValueError("Quantization is only compatible with the LoRA method.")
if finetuning_args.pissa_init:
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
if model_args.resize_vocab: if model_args.resize_vocab:
raise ValueError("Cannot resize embedding layers of a quantized model.") raise ValueError("Cannot resize embedding layers of a quantized model.")
@@ -162,6 +183,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     ):
         raise ValueError("PPO only accepts wandb or tensorboard logger.")

+    if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
+        raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
+
     if training_args.max_steps == -1 and data_args.streaming:
         raise ValueError("Please specify `max_steps` in streaming mode.")
@@ -171,9 +195,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if training_args.do_train and model_args.quantization_device_map == "auto":
         raise ValueError("Cannot use device map for quantized models in training.")

-    if finetuning_args.use_dora and model_args.use_unsloth:
-        raise ValueError("Unsloth does not support DoRA.")
-
     if finetuning_args.pure_bf16:
         if not is_torch_bf16_gpu_available():
             raise ValueError("This device does not support `pure_bf16`.")
@@ -184,14 +205,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if (
         finetuning_args.use_galore
         and finetuning_args.galore_layerwise
-        and training_args.parallel_mode.value == "distributed"
+        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
     ):
         raise ValueError("Distributed training does not support layer-wise GaLore.")

     if (
         finetuning_args.use_badam
         and finetuning_args.badam_mode == "layer"
-        and training_args.parallel_mode.value == "distributed"
+        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
     ):
         raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
@@ -233,7 +254,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     # Post-process training arguments
     if (
-        training_args.parallel_mode.value == "distributed"
+        training_args.parallel_mode == ParallelMode.DISTRIBUTED
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
@@ -293,7 +314,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
             training_args.local_rank,
             training_args.device,
             training_args.n_gpu,
-            training_args.parallel_mode.value == "distributed",
+            training_args.parallel_mode == ParallelMode.DISTRIBUTED,
             str(model_args.compute_dtype),
         )
     )
@@ -332,6 +353,7 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     if model_args.export_dir is not None and model_args.export_device == "cpu":
         model_args.device_map = {"": torch.device("cpu")}
+        model_args.model_max_length = data_args.cutoff_len
     else:
         model_args.device_map = "auto"

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from llamafactory.train.tuner import run_exp

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from .loader import load_config, load_model, load_tokenizer
 from .model_utils.misc import find_all_linear_modules
 from .model_utils.valuehead import load_valuehead_params

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import re
 from typing import TYPE_CHECKING
@@ -25,8 +39,12 @@ def _setup_full_tuning(
model: "PreTrainedModel", model: "PreTrainedModel",
model_args: "ModelArguments", model_args: "ModelArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool, cast_trainable_params_to_fp32: bool,
) -> None: ) -> None:
if not is_trainable:
return
logger.info("Fine-tuning method: Full") logger.info("Fine-tuning method: Full")
forbidden_modules = set() forbidden_modules = set()
if model_args.visual_inputs and finetuning_args.freeze_vision_tower: if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
@@ -47,8 +65,12 @@ def _setup_freeze_tuning(
model: "PreTrainedModel", model: "PreTrainedModel",
model_args: "ModelArguments", model_args: "ModelArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
is_trainable: bool,
cast_trainable_params_to_fp32: bool, cast_trainable_params_to_fp32: bool,
) -> None: ) -> None:
if not is_trainable:
return
logger.info("Fine-tuning method: Freeze") logger.info("Fine-tuning method: Freeze")
if model_args.visual_inputs: if model_args.visual_inputs:
config = model.config.text_config config = model.config.text_config
@@ -132,7 +154,9 @@ def _setup_lora_tuning(
     is_trainable: bool,
     cast_trainable_params_to_fp32: bool,
 ) -> "PeftModel":
-    logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+    if is_trainable:
+        logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))

     adapter_to_resume = None

     if model_args.adapter_name_or_path is not None:
@@ -155,8 +179,16 @@ def _setup_lora_tuning(
         else:
             adapter_to_merge = model_args.adapter_name_or_path

+        init_kwargs = {
+            "subfolder": model_args.adapter_folder,
+            "offload_folder": model_args.offload_folder,
+            "cache_dir": model_args.cache_dir,
+            "revision": model_args.model_revision,
+            "token": model_args.hf_hub_token,
+        }
         for adapter in adapter_to_merge:
-            model: "LoraModel" = PeftModel.from_pretrained(model, adapter, offload_folder=model_args.offload_folder)
+            model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
             model = model.merge_and_unload()

         if len(adapter_to_merge) > 0:
@@ -166,12 +198,9 @@ def _setup_lora_tuning(
         if model_args.use_unsloth:
             model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
         else:
-            model = PeftModel.from_pretrained(
-                model,
-                adapter_to_resume,
-                is_trainable=is_trainable,
-                offload_folder=model_args.offload_folder,
-            )
+            model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)
+
+        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

     if is_trainable and adapter_to_resume is None:  # create new lora weights while training
         if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
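
With `init_kwargs` consolidated, every `PeftModel.from_pretrained` call now honors the new `adapter_folder` option alongside the cache, revision, and token settings. A hedged sketch of the equivalent standalone call (the adapter path and subfolder are placeholders):

    from peft import PeftModel

    # Illustrative only: load an adapter stored in a subfolder of a repo,
    # mirroring the keyword set assembled in init_kwargs above.
    model = PeftModel.from_pretrained(
        model,
        "user/some-lora-adapter",     # placeholder adapter repo or local path
        is_trainable=False,
        subfolder="checkpoint-1000",  # what --adapter_folder maps to
        offload_folder="offload",
        cache_dir=None,
        revision="main",
        token=None,
    )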
@@ -216,6 +245,14 @@ def _setup_lora_tuning(
             if model_args.use_unsloth:
                 model = get_unsloth_peft_model(model, model_args, peft_kwargs)
             else:
+                if finetuning_args.pissa_init:
+                    if finetuning_args.pissa_iter == -1:
+                        logger.info("Using PiSSA initialization.")
+                        peft_kwargs["init_lora_weights"] = "pissa"
+                    else:
+                        logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
+                        peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)
+
                 lora_config = LoraConfig(
                     task_type=TaskType.CAUSAL_LM,
                     inference_mode=False,
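
PiSSA is wired in through PEFT's `init_lora_weights` hook: `"pissa"` initializes the adapter from a full SVD of each target weight, while `"pissa_niter_{n}"` uses n steps of fast SVD (FSVD) for a much cheaper approximation. A minimal standalone sketch against the public peft API (the base model and target modules are placeholders):

    from peft import LoraConfig, TaskType, get_peft_model

    # peft accepts "pissa" or "pissa_niter_[number]" as init_lora_weights.
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],  # placeholder targets
        init_lora_weights="pissa_niter_16",   # 16 FSVD steps
    )
    model = get_peft_model(base_model, lora_config)  # base_model assumed loaded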
@@ -227,9 +264,6 @@ def _setup_lora_tuning(
     for param in filter(lambda p: p.requires_grad, model.parameters()):
         param.data = param.data.to(torch.float32)

-    if model_args.adapter_name_or_path is not None:
-        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
-
     return model
@@ -247,29 +281,37 @@ def init_adapter(
     Note that the trainable parameters must be cast to float32.
     """
-    if (not is_trainable) and model_args.adapter_name_or_path is None:
-        logger.info("Adapter is not found at evaluation, load the base model.")
-        return model
-
-    if finetuning_args.finetuning_type != "lora" and getattr(model, "quantization_method", None):
-        raise ValueError("You can only use lora for quantized models.")
-
-    if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam:
+    if is_trainable and getattr(model, "quantization_method", None) is not None:
+        if finetuning_args.finetuning_type != "lora":
+            raise ValueError("Quantized models can only be used for the LoRA tuning.")
+
+        if finetuning_args.pissa_init:
+            raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
+
+    # cast trainable parameters to float32 if:
+    # 1. is_trainable and quantization_bit is not None (qlora)
+    # 2. is_trainable and not deepspeed zero3 and not fsdp (zero3 or fsdp already in float32)
+    # 3. is_trainable and not pure_bf16 and not badam
+    if not is_trainable:
+        cast_trainable_params_to_fp32 = False
+    elif model_args.quantization_bit is None and (
+        is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam
+    ):
         logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
         cast_trainable_params_to_fp32 = False
     else:
         logger.info("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True

-    if is_trainable and finetuning_args.finetuning_type == "full":
-        _setup_full_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
-
-    if is_trainable and finetuning_args.finetuning_type == "freeze":
-        _setup_freeze_tuning(model, model_args, finetuning_args, cast_trainable_params_to_fp32)
-
-    if finetuning_args.finetuning_type == "lora":
+    if finetuning_args.finetuning_type == "full":
+        _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+    elif finetuning_args.finetuning_type == "freeze":
+        _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
+    elif finetuning_args.finetuning_type == "lora":
         model = _setup_lora_tuning(
             config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
         )
+    else:
+        raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))

     return model
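
Condensed, the new float32-cast decision reduces to three mutually exclusive cases. A hedged restatement (the predicate names here are illustrative, not from the codebase):

    def should_cast_to_fp32(is_trainable: bool, is_quantized: bool, keeps_native_precision: bool) -> bool:
        # Restates the branch above: inference never casts; an unquantized
        # model under ZeRO-3/FSDP/pure_bf16/BAdam keeps native precision;
        # every other trainable path (including QLoRA) upcasts to float32.
        if not is_trainable:
            return False
        if not is_quantized and keeps_native_precision:
            return False
        return True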

View File

@@ -1,3 +1,17 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict

 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer

View File

@@ -1,7 +1,22 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import TYPE_CHECKING

+from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+
 from ...extras.logging import get_logger
-from ...extras.packages import is_flash_attn2_available, is_sdpa_available

 if TYPE_CHECKING:
@@ -21,13 +36,13 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
requested_attn_implementation = "eager" requested_attn_implementation = "eager"
elif model_args.flash_attn == "sdpa": elif model_args.flash_attn == "sdpa":
if not is_sdpa_available(): if not is_torch_sdpa_available():
logger.warning("torch>=2.1.1 is required for SDPA attention.") logger.warning("torch>=2.1.1 is required for SDPA attention.")
return return
requested_attn_implementation = "sdpa" requested_attn_implementation = "sdpa"
elif model_args.flash_attn == "fa2": elif model_args.flash_attn == "fa2":
if not is_flash_attn2_available(): if not is_flash_attn_2_available():
logger.warning("FlashAttention-2 is not installed.") logger.warning("FlashAttention-2 is not installed.")
return return
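
Switching to the upstream `transformers.utils` checks retires the project's home-grown availability helpers. A minimal sketch of backend selection using the same two functions (standalone, outside this codebase):

    from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

    def pick_attn_implementation(preferred: str = "fa2") -> str:
        # Fall back from FlashAttention-2 to SDPA to eager, by availability.
        if preferred == "fa2" and is_flash_attn_2_available():
            return "flash_attention_2"
        if is_torch_sdpa_available():
            return "sdpa"
        return "eager"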

View File

@@ -1,3 +1,21 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's Transformers and PEFT library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
+# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import inspect
 from functools import partial
 from types import MethodType
@@ -68,7 +86,6 @@ def prepare_model_for_training(
         (1) cast the layernorm in fp32
         (2) make output embedding layer require grads
         (3) add the upcasting of the lm_head in fp32
-    Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
     """
     if model_args.upcast_layernorm:
         logger.info("Upcasting layernorm weights in float32.")

Some files were not shown because too many files have changed in this diff.