79 Commits

Author SHA1 Message Date
hiyouga
debfd46749 release v0.5.2
Former-commit-id: 0189867816b0eab92fb2a1b5f1b1da079bd161a7
2024-02-20 11:12:43 +08:00
hiyouga
5ccf8fcd6b update webui
Former-commit-id: 9e0f7c362d40b78d57e77d52eaa96e678cebadcd
2024-02-19 16:49:58 +08:00
hiyouga
7bd1991513 add test scripts
Former-commit-id: fdaa4843961257b48cc32d83d30f2efe18b9fd5a
2024-02-19 02:09:13 +08:00
hiyouga
456e4ca569 fix safetensors
Former-commit-id: 06478ae5302d5fc6eb7afedc69335ce2f32808c6
2024-02-18 18:12:16 +08:00
hiyouga
6bf0fe4913 fix #2481
Former-commit-id: 2a4e3e4a26a2fad77ccc476be7d45434b8af4a55
2024-02-15 19:07:47 +08:00
hiyouga
596b6828cb support llama pro #2338 , add rslora
Former-commit-id: 40d659b7f30dd5a004703c176ec1f22dc864e505
2024-02-15 02:27:36 +08:00
hoshi-hiyouga
b403f8d8a8 Merge pull request #2474 from younesbelkada/add-hf-tags
FEAT: add HF tags for models that have been trained with llama-factory
Former-commit-id: f35d96817e61da9fa7789b93b0350c9f95afc40a
2024-02-14 10:26:03 +08:00
younesbelkada
590b6c2143 add v1 hf tags
Former-commit-id: a29cc9f4472c95cd6a43ea350ab728e0a8069c6e
2024-02-13 05:58:49 +00:00
hiyouga
5537ef1e7d fix #2471
Former-commit-id: a408be8be1cf99cd4468a9905c27ec454f312b9a
2024-02-12 21:07:46 +08:00
hiyouga
5f83860aa1 add option to disable version check
Former-commit-id: fd769cb2de696aee3c5e882237e16eace6a9d675
2024-02-10 22:31:23 +08:00
hiyouga
62b6a7971a update data/readme
Former-commit-id: aa566e3cea5bc75688b4399a9da07be0b35b921c
2024-02-10 21:04:29 +08:00
hiyouga
1d16e87c5f update default template
Former-commit-id: f32b55649a9f95109a6d180216eb67f959d060da
2024-02-10 16:44:47 +08:00
hiyouga
1955a8ea5a improve aligner
Former-commit-id: cc7296b92e10c24967fc753393275b71d300683f
2024-02-10 16:39:19 +08:00
hoshi-hiyouga
a41fa6e730 Merge pull request #2462 from mnmueller/main
Enable Parsing of SlimOrca

Former-commit-id: 99eed520b87152ca6b89c2a068b09200fd45f30d
2024-02-09 22:55:48 +08:00
hiyouga
b98a64448a improve fix tokenizer
Former-commit-id: 57b138abad6397596bc47be94e092e8fabedc06f
2024-02-09 14:53:14 +08:00
Mark Mueller
1ce82f391a Slim Orca data parsing
Former-commit-id: f2d8efede7e20edafed0d5446eb64f2d419949b1
2024-02-08 19:32:20 +01:00
Mark Mueller
4d473894fd Slim Orca data parsing
Former-commit-id: ca57d27c39d4e7bc3dd7c3207a23d23d2cbd446b
2024-02-08 17:56:18 +01:00
Mark Mueller
5788b7c7d0 Slim Orca data parsing
Former-commit-id: 3016427be4e63fd25f40bc5a0d1f8cedc0997334
2024-02-08 17:54:18 +01:00
Mark Mueller
04515f6b55 Slim Orca data parsing
Former-commit-id: 4dca3907964d27abc2b21eb55c75676901c98912
2024-02-08 17:52:36 +01:00
Mark Mueller
96f8ccf3d5 SlimOrca aligner
Former-commit-id: 928dda93867c2327a7957c04648592044ccf9daf
2024-02-08 08:28:32 -08:00
hoshi-hiyouga
2c3ef480a6 Merge pull request #2423 from mayflower/main
Support for German SFT and DPO

Former-commit-id: 8e282e4e6bee6493b1bd38ba239ca49a6a840a92
2024-02-07 15:58:20 +08:00
hiyouga
fa6873122c Update tests.yml
Former-commit-id: c882b7cf339304ff16a36b1544a3b5f1194ef346
2024-02-07 01:18:22 +08:00
hiyouga
34bc0c22b1 lint
Former-commit-id: 6b1f89b6494e9b6b087fe90600617a3024e014e5
2024-02-07 01:10:04 +08:00
hiyouga
e5484b2729 Update pyproject.toml
Former-commit-id: 650251ea77fae2e2595ca804f49efdd230dbb5b1
2024-02-07 00:45:58 +08:00
hiyouga
f67f781fed update gc kwargs
Former-commit-id: 0cb81c156bc8c21a4bbdd3289a491f78dfcaf730
2024-02-07 00:38:24 +08:00
hiyouga
b564b97b7e fix #2438
Former-commit-id: 412d856eeada2abcea598fac0a8d35ae90cc9c01
2024-02-06 15:23:08 +08:00
hiyouga
0dd68d1e06 add models
Former-commit-id: 0fdf61b2f765c125acda4f406eb25b3e59e75db2
2024-02-06 14:57:23 +08:00
hiyouga
73f40f1ca4 support qwen1.5
Former-commit-id: 8a03a572b058c5cc4ff598670dc8595b2b97e374
2024-02-06 00:10:51 +08:00
hoshi-hiyouga
ea53bebac4 fix #2436
Update test_toolcall.py

Former-commit-id: 39c539b6470c532ac639efbd2a1c485d2f5d485f
2024-02-05 22:55:28 +08:00
hoshi-hiyouga
00418012bd Update test_toolcall.py
Former-commit-id: f50a684a9d6fc2351436d3d7020dc84bc1553a5d
2024-02-05 22:51:03 +08:00
hoshi-hiyouga
5f3d8c514b Update test_toolcall.py
Former-commit-id: 97bcae546ab80737a906e5e28953f41b657f6c99
2024-02-05 22:50:43 +08:00
tao.jun
cb39a3f1c4 Update test_toolcall.py
Add openai version notes

Former-commit-id: 9ea4ab214e64f73ec902e76b82fc42419571fd66
2024-02-05 20:49:23 +08:00
Johann-Peter Hartmann
4d78fe6ece Merge branch 'hiyouga:main' into main
Former-commit-id: efbb0153981d0650f3a581e324b83054ca8063c1
2024-02-04 13:55:00 +00:00
hiyouga
a3e3ea9846 fix #2421
Former-commit-id: 43918c12310f7560d3820e5c6d72964309afeb8b
2024-02-04 21:02:55 +08:00
Johann-Peter Hartmann
feba34e82d Merge branch 'hiyouga:main' into main
Former-commit-id: 0395d0aafb69e86645e6b0a36b8f8dadb82219e0
2024-02-04 12:51:25 +00:00
hiyouga
e134013e04 fix reserved label len
Former-commit-id: b06d6c05a1911f329252a7572240048e456affdc
2024-02-04 17:54:26 +08:00
hiyouga
5589d0296a fix #2420
Former-commit-id: 7a34087e4db62e603c9a9a26d8ff3910d7b10c40
2024-02-04 15:51:47 +08:00
hiyouga
de0ebab464 fix #2189
Former-commit-id: b3d81b229d376671e1c12229aeb487b0d84f2548
2024-02-04 00:47:37 +08:00
hiyouga
f2e7122a96 bump up transformers version
Former-commit-id: 82f4d4301ed9f31b160d6313a1d2d44a22865f4d
2024-02-04 00:01:16 +08:00
hiyouga
996cc5d900 fix #2397
Former-commit-id: 7404692808f2288d539668d364965ad104dacadb
2024-02-03 23:45:31 +08:00
hiyouga
a2ae5bd867 add hint for freeze #2412
Former-commit-id: 9600c93633629605573d908019563fa3870ad6f8
2024-02-03 23:38:56 +08:00
hiyouga
5fa52e87cb fix #2376
Former-commit-id: 8e2cfa7cca21b7fd4538d72114e36f704bcc82fe
2024-02-03 23:14:31 +08:00
hiyouga
bcd76d2c7a support minicpm #2404
Former-commit-id: 4449e91cbee8fd804cf8bf1ff6b9f5301fde94ed
2024-02-03 22:36:46 +08:00
Johann-Peter Hartmann
36fcbedc11 add simple german chatml template chatml_de
Former-commit-id: 9f1d67c09f1af2c7aa383adec66842cacde92e33
2024-02-03 09:01:15 +01:00
Johann-Peter Hartmann
1dad01cc53 Merge branch 'hiyouga:main' into main
Former-commit-id: c350237d891df7edd7e681f9da5ac1446fdeb568
2024-02-03 08:43:12 +01:00
hoshi-hiyouga
5fb21f6e54 Merge pull request #2411 from lxsyz/main
fix eos_token_id=0 bug

Former-commit-id: 019a353e74ec70a9a2d8987df1ed19483413211a
2024-02-02 17:38:16 +08:00
Fallen Angel
08dfac8352 fix eos_token_id=0 bug
when eos_token_id=0, eos_token will never be added

Former-commit-id: 576b4881c386d897462a875b28066ce9d6e06dd5
2024-02-02 17:34:48 +08:00
Johann-Peter Hartmann
956751e419 Merge branch 'hiyouga:main' into main
Former-commit-id: 25b0a11c715f87812edba1ca14d3122a75f421de
2024-01-31 14:05:52 +01:00
hiyouga
fe2ae04c91 fix #2388
Former-commit-id: 203a36c9adfd9aa0f35fbf8089c9138534d68c53
2024-01-31 17:23:56 +08:00
hiyouga
5b8712d061 fix autoset attn impl, update data readme
Former-commit-id: 34a6e5f82baf45cc8dbb11f9f7ab4a480ab7ec5c
2024-01-31 11:58:07 +08:00
Johann-Peter Hartmann
dc7ff90c1e Add support for german datasets
Former-commit-id: bbc038aa236952597e97d1ccf1ae2d64a16339b5
2024-01-30 10:18:01 +01:00
hiyouga
1ace676170 fix #2320
Former-commit-id: e0b0c4415aaf80e75f6dd4f3777a0616b0e60f84
2024-01-24 16:19:18 +08:00
hoshi-hiyouga
8947a87b95 Merge pull request #2319 from ftgreat/main
Add patch_mixtral_replace_moe_impl for full training of Mixtral using DeepSpeed Zero3

Former-commit-id: 0fadcd5f9deb9f03d341b6611c15f337f07e32d1
2024-01-24 15:32:26 +08:00
ldwang
786a2f1103 Add patch_mixtral_replace_moe_impl for full training of Mixtral using DeepSpeed Zero3.
Signed-off-by: ldwang <ftgreat@gmail.com>

Former-commit-id: 5f50c02f0e425737cd80abdf8fde9e25abf13083
2024-01-24 15:25:31 +08:00
ldwang
36ac14a566 Add patch_mixtral_replace_moe_impl for full training of Mixtral using DeepSpeed Zero3.
Signed-off-by: ldwang <ftgreat@gmail.com>

Former-commit-id: d1413dcec8a3b1d671f240b82a689c72b54d7b93
2024-01-24 14:43:16 +08:00
hiyouga
7a048fc91d add hint
Former-commit-id: c540ef41bda61993b83ef8cfe3c84b1d169e984c
2024-01-22 23:32:01 +08:00
hoshi-hiyouga
3f3756b113 Merge pull request #2283 from A-Cepheus/main
fix: ZeRO3 does not work with MoE models
Former-commit-id: f5ea760abec2aac8d29ce5c945647be05648e676
2024-01-22 23:28:45 +08:00
hoshi-hiyouga
b36c4b99cc Update patcher.py
Former-commit-id: 33556cc6b0b65cc6db02e66f4f6e75112c33d966
2024-01-22 23:27:39 +08:00
hoshi-hiyouga
9856a2276e Update tests.yml
Former-commit-id: 34151675388701afa40220729a63b0d7b2fa2a7c
2024-01-22 23:22:15 +08:00
hoshi-hiyouga
b6dc3ed3ad Create tests.yml
Former-commit-id: 9443ad76b7ef3ef1f3d184ef60652947d2c30609
2024-01-22 23:13:04 +08:00
hiyouga
75be329994 fix #2282 and update tool prompt
Former-commit-id: 1c412f803866bde32b76f7c26c7b464b6b3651f3
2024-01-22 22:27:30 +08:00
hiyouga
1fe1ca1c8b add orion models
Former-commit-id: a34db89d2a281d1a1ace29dfd5bd5d4ff7c2f657
2024-01-22 21:26:53 +08:00
A-Cepheus
882a6a1d51 🐞 fix: typo
Former-commit-id: 57a3687ecd23237559aee0e8e811b782846f2415
2024-01-22 16:04:39 +08:00
A-Cepheus
712ab4ae7a 🐞 fix: typo, move MoE fix to patcher
Former-commit-id: 4ff28e99ff9b48df7150591c6bbd3723f22b7715
2024-01-22 16:01:58 +08:00
A-Cepheus
18ad259fb3 fix: ZeRO3 does not work with MoE models
Former-commit-id: b2844c049a88ea89f8e1812e2d2e8662b4002965
2024-01-22 15:21:14 +08:00
hiyouga
fe4d93c6db add array param format
Former-commit-id: bf910f8a5b21ee552fa9ab069610a3f5f611de57
2024-01-21 22:17:48 +08:00
hiyouga
c6ba588e37 update tool test
Former-commit-id: 1d63ccc2866632596310235de15fdff660f6bee5
2024-01-21 19:41:46 +08:00
hiyouga
3fda60fca0 fix api
Former-commit-id: cca004da28aaaa0788eaea62b83d3402b38a3011
2024-01-21 19:15:27 +08:00
hiyouga
96531a0ef8 fix #2268
Former-commit-id: 300ecf9b9d7fd99fbb68f3d086e3ad973c2f894e
2024-01-21 14:11:38 +08:00
hiyouga
7abc3065fb tiny fix
Former-commit-id: 66839ae94985ddfa13eca4542127119c919b9648
2024-01-21 13:26:12 +08:00
hoshi-hiyouga
013ded4bac Merge pull request #2266 from yhyu13/fix_export_model_dtype
Remove manually set use_cache; torch_dtype is not str, save model as b…

Former-commit-id: 8c0827ba92a458e18c3b68af0330af3a65149f96
2024-01-21 12:40:39 +08:00
hoshi-hiyouga
010c3c7348 Merge branch 'main' into fix_export_model_dtype
Former-commit-id: 6c7d2729f28eb37a97820d73c05648eb7fb2ca87
2024-01-21 12:40:24 +08:00
hoshi-hiyouga
bf075c075c Update tuner.py
Former-commit-id: 691420661f7115f809e76484c1f29f74637e7cd0
2024-01-21 12:39:38 +08:00
hoshi-hiyouga
41b34e5f60 Merge pull request #2262 from fenglui/main
fix torch_dtype check of export_model

Former-commit-id: 37cacf73a534fed1b06b4f3c6724f3568ce095e3
2024-01-21 12:34:37 +08:00
hiyouga
5a889398e7 format
Former-commit-id: f28a1a0c1cdd0062ad7b6c2826f8ec107a200cff
2024-01-21 12:34:17 +08:00
hoshi-hiyouga
054cae86d8 Merge pull request #2264 from seoeaa/main
add russian lang

Former-commit-id: 15d1747de54efe69ed9f4cfd8f296fe8dd09a5c9
2024-01-21 12:25:24 +08:00
yhyu13
cd1cb8b83c Remove manually set use_cache; torch_dtype is not a str, so saving the model as bfloat16 used to fail;
Former-commit-id: 75557fb5df16fd6eda7586cf041a16822dcfee8e
2024-01-21 11:12:15 +08:00
Aleksandr
a34779c027 add russian lang
Former-commit-id: f8ce6d75b56439027bb17ff4e59eeb9eb3b9bd34
2024-01-21 04:28:14 +03:00
fenglui
d19cb77d74 fix torch_dtype check of export_model
Former-commit-id: 8813181b6bffa76e5c7cb1f4caceada611c54b9d
2024-01-21 05:01:53 +08:00
50 changed files with 2378 additions and 601 deletions

.github/workflows/tests.yml (new file, 29 lines)

@@ -0,0 +1,29 @@
name: tests
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
jobs:
check_code_quality:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install black ruff
- name: Check quality
run: |
make style && make quality


@@ -55,14 +55,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog ## Changelog
[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `tests/llama_pro.py` for usage.
[24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`. [24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`.
<details><summary>Full Changelog</summary>
[23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details. [23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.
[23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement). [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).
<details><summary>Full Changelog</summary>
[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage. [23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`. [23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neftune_noise_alpha` argument to activate NEFTune, e.g., `--neftune_noise_alpha 5`.
@@ -110,6 +114,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral | | [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - | | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen | | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse | | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
| [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi | | [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan | | [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
@@ -154,8 +159,8 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca) - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Self-cognition (zh)](data/self_cognition.json) - [Self Cognition (zh)](data/self_cognition.json)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
@@ -171,8 +176,10 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca) - [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -185,6 +192,15 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m) - [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k) - [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2) - [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
</details> </details>
@@ -194,6 +210,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
</details> </details>


@@ -55,14 +55,18 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
## 更新日志 ## 更新日志
[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `tests/llama_pro.py`
[24/02/05] Qwen1.5Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。 [24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。
<details><summary>展开日志</summary>
[23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 1.7 倍的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。 [23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 1.7 倍的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。
[23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。 [23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。
<details><summary>展开日志</summary>
[23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#使用魔搭社区可跳过)。 [23/12/01] 我们支持了从 **[魔搭社区](https://modelscope.cn/models)** 下载预训练模型和数据集。详细用法请参照 [此教程](#使用魔搭社区可跳过)。
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neftune_noise_alpha` 参数启用 NEFTune例如 `--neftune_noise_alpha 5` [23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neftune_noise_alpha` 参数启用 NEFTune例如 `--neftune_noise_alpha 5`
@@ -110,6 +114,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
| [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral | | [Mixtral](https://huggingface.co/mistralai) | 8x7B | q_proj,v_proj | mistral |
| [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - | | [Phi-1.5/2](https://huggingface.co/microsoft) | 1.3B/2.7B | q_proj,v_proj | - |
| [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen | | [Qwen](https://huggingface.co/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
| [Qwen1.5](https://huggingface.co/Qwen) | 0.5B/1.8B/4B/7B/14B/72B | q_proj,v_proj | qwen |
| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse | | [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | q_proj,v_proj | xverse |
| [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi | | [Yi](https://huggingface.co/01-ai) | 6B/34B | q_proj,v_proj | yi |
| [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan | | [Yuan](https://huggingface.co/IEITYuan) | 2B/51B/102B | q_proj,v_proj | yuan |
@@ -154,8 +159,8 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca) - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - [Alpaca GPT4 (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Self-cognition (zh)](data/self_cognition.json) - [Self Cognition (zh)](data/self_cognition.json)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection) - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
@@ -171,8 +176,10 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca) - [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa) - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn) - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
@@ -185,6 +192,15 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m) - [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k) - [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
- [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2) - [Glaive Function Calling V2 (en)](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2)
- [Open Assistant (de)](https://huggingface.co/datasets/mayflowergmbh/oasst_de)
- [Dolly 15k (de)](https://huggingface.co/datasets/mayflowergmbh/dolly-15k_de)
- [Alpaca GPT4 (de)](https://huggingface.co/datasets/mayflowergmbh/alpaca-gpt4_de)
- [OpenSchnabeltier (de)](https://huggingface.co/datasets/mayflowergmbh/openschnabeltier_de)
- [Evol Instruct (de)](https://huggingface.co/datasets/mayflowergmbh/evol-instruct_de)
- [Dolphin (de)](https://huggingface.co/datasets/mayflowergmbh/dolphin_de)
- [Booksum (de)](https://huggingface.co/datasets/mayflowergmbh/booksum_de)
- [Airoboros (de)](https://huggingface.co/datasets/mayflowergmbh/airoboros-3.0_de)
- [Ultrachat (de)](https://huggingface.co/datasets/mayflowergmbh/ultra-chat_de)
</details> </details>
@@ -194,6 +210,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1) - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar) - [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
- [Orca DPO (de)](https://huggingface.co/datasets/mayflowergmbh/intel_orca_dpo_pairs_de)
</details> </details>


@@ -11,15 +11,23 @@ If you are using a custom dataset, please provide your dataset definition in the
"folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)", "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
"ranking": "whether the dataset is a preference dataset or not. (default: false)", "ranking": "whether the dataset is a preference dataset or not. (default: false)",
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})", "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
"columns": { "columns (optional)": {
"prompt": "the column name in the dataset containing the prompts. (default: instruction, for alpaca)", "prompt": "the column name in the dataset containing the prompts. (default: instruction)",
"query": "the column name in the dataset containing the queries. (default: input, for alpaca)", "query": "the column name in the dataset containing the queries. (default: input)",
"response": "the column name in the dataset containing the responses. (default: output, for alpaca)", "response": "the column name in the dataset containing the responses. (default: output)",
"history": "the column name in the dataset containing the histories. (default: None, for alpaca)", "history": "the column name in the dataset containing the histories. (default: None)",
"messages": "the column name in the dataset containing the messages. (default: conversations, for sharegpt)", "messages": "the column name in the dataset containing the messages. (default: conversations)",
"role": "the key in the message represents the identity. (default: from, for sharegpt)", "system": "the column name in the dataset containing the system prompts. (default: None)",
"content": "the key in the message represents the content. (default: value, for sharegpt)", "tools": "the column name in the dataset containing the tool description. (default: None)"
"system": "the column name in the dataset containing the system prompts. (default: None, for both)" },
"tags (optional, used for the sharegpt format)": {
"role_tag": "the key in the message represents the identity. (default: from)",
"content_tag": "the key in the message represents the content. (default: value)",
"user_tag": "the value of the role_tag represents the user. (default: human)",
"assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)",
"observation_tag": "the value of the role_tag represents the tool results. (default: observation)",
"function_tag": "the value of the role_tag represents the function call. (default: function_call)",
"system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)"
} }
} }
``` ```
@@ -57,9 +65,9 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
} }
``` ```
where the `prompt` and `response` columns should contain non-empty values, represent instruction and response respectively. The `query` column will be concatenated with the `prompt` column and used as input for the model. The `query` column will be concatenated with the `prompt` column and used as the user prompt, then the user prompt would be `prompt\nquery`. The `response` column represents the model response.
The `system` column will be used as the system prompt in the template. The `history` column is a list consisting string tuples representing query-response pairs in history. Note that the responses **in each round will be used for training**. The `system` column will be used as the system prompt. The `history` column is a list consisting string tuples representing prompt-response pairs in the history. Note that the responses in the history **will also be used for training**.
For the pre-training datasets, only the `prompt` column will be used for training. For the pre-training datasets, only the `prompt` column will be used for training.
@@ -91,7 +99,8 @@ The dataset in sharegpt format should follow the below format:
"value": "model response" "value": "model response"
} }
], ],
"system": "system prompt (optional)" "system": "system prompt (optional)",
"tools": "tool description (optional)"
} }
] ]
``` ```
@@ -102,13 +111,18 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
"dataset_name": { "dataset_name": {
"columns": { "columns": {
"messages": "conversations", "messages": "conversations",
"role": "from", "system": "system",
"content": "value", "tools": "tools"
"system": "system" },
"tags": {
"role_tag": "from",
"content_tag": "value",
"user_tag": "human",
"assistant_tag": "gpt"
} }
} }
``` ```
where the `messages` column should be a list whose length is even, and follow the `u/a/u/a/u/a` order. where the `messages` column should be a list following the `u/a/u/a/u/a` order.
Pre-training datasets and preference datasets are incompatible with the sharegpt format yet. Pre-training datasets and preference datasets are incompatible with the sharegpt format yet.
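The new `columns`/`tags` split documented above is easier to read in one concrete entry. Below is a minimal sketch of a `dataset_info.json` record for a local sharegpt-style file, using only keys documented in the hunk above; the dataset name and file name are hypothetical.

```python
import json

# Hypothetical dataset_info.json entry for a local sharegpt-style file.
# Only keys documented in the data/README.md hunk above are used; the dataset
# name and file name are made up for illustration.
entry = {
    "my_toolcall_data": {
        "file_name": "my_toolcall_data.json",
        "formatting": "sharegpt",
        "columns": {
            "messages": "conversations",
            "system": "system",
            "tools": "tools",
        },
        "tags": {
            "role_tag": "from",
            "content_tag": "value",
            "user_tag": "human",
            "assistant_tag": "gpt",
            "observation_tag": "observation",
            "function_tag": "function_call",
        },
    }
}

print(json.dumps(entry, indent=2, ensure_ascii=False))
```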


@@ -11,15 +11,23 @@
"folder": "Hugging Face 仓库的文件夹名称可选默认None", "folder": "Hugging Face 仓库的文件夹名称可选默认None",
"ranking": "是否为偏好数据集可选默认False", "ranking": "是否为偏好数据集可选默认False",
"formatting": "数据集格式可选默认alpaca可以为 alpaca 或 sharegpt", "formatting": "数据集格式可选默认alpaca可以为 alpaca 或 sharegpt",
"columns": { "columns(可选)": {
"prompt": "数据集代表提示词的表头名称默认instruction,用于 alpaca 格式", "prompt": "数据集代表提示词的表头名称默认instruction",
"query": "数据集代表请求的表头名称默认input,用于 alpaca 格式", "query": "数据集代表请求的表头名称默认input",
"response": "数据集代表回答的表头名称默认output,用于 alpaca 格式", "response": "数据集代表回答的表头名称默认output",
"history": "数据集代表历史对话的表头名称默认None,用于 alpaca 格式", "history": "数据集代表历史对话的表头名称默认None",
"messages": "数据集代表消息列表的表头名称默认conversations,用于 sharegpt 格式", "messages": "数据集代表消息列表的表头名称默认conversations",
"role": "消息中代表发送者身份的键名默认from用于 sharegpt 格式", "system": "数据集代表系统提示的表头名称默认None",
"content": "消息中代表文本内容的键名默认value用于 sharegpt 格式)", "tools": "数据集代表工具描述的表头名称默认None"
"system": "数据集代表系统提示的表头名称默认None用于两种格式" },
"tags可选用于 sharegpt 格式)": {
"role_tag": "消息中代表发送者身份的键名默认from",
"content_tag": "消息中代表文本内容的键名默认value",
"user_tag": "消息中代表用户的 role_tag默认human",
"assistant_tag": "消息中代表助手的 role_tag默认gpt",
"observation_tag": "消息中代表工具返回结果的 role_tag默认observation",
"function_tag": "消息中代表工具调用的 role_tag默认function_call",
"system_tag": "消息中代表系统提示的 role_tag默认system会覆盖 system 列)"
} }
} }
``` ```
@@ -57,9 +65,9 @@
} }
``` ```
其中 `prompt``response` 列应当是非空的字符串,分别代表用户指令和模型回答。`query` 列的内容将会和 `prompt` 列拼接作为模型输入 其中 `query` 列对应的内容会与 `prompt` 列对应的内容拼接后作为用户指令,即用户指令为 `prompt\nquery``response` 列对应的内容为模型回答
`system`模板中的系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意每轮的模型回答**会被用于训练**。 `system` 列对应的内容将被作为系统提示词。`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意历史消息中的回答**会被用于训练**。
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。 对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
@@ -91,7 +99,8 @@
"value": "模型回答" "value": "模型回答"
} }
], ],
"system": "系统提示词(选填)" "system": "系统提示词(选填)",
"tools": "工具描述(选填)"
} }
] ]
``` ```
@@ -102,13 +111,18 @@
"数据集名称": { "数据集名称": {
"columns": { "columns": {
"messages": "conversations", "messages": "conversations",
"role": "from", "system": "system",
"content": "value", "tools": "tools"
"system": "system" },
"tags": {
"role_tag": "from",
"content_tag": "value",
"user_tag": "human",
"assistant_tag": "gpt"
} }
} }
``` ```
其中 `messages`必须为偶数长度的列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。 其中 `messages`应当是一个列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
预训练数据集和偏好数据集尚不支持 sharegpt 格式。 预训练数据集和偏好数据集尚不支持 sharegpt 格式。


@@ -1 +1 @@
fc9a6a3458caca2af8dafc6181773fe10c6d8657 34c723573fbc2d7601f6d9c882ccf5aa4f9bcc4b


@@ -7,14 +7,23 @@ line-length = 119
target-version = ["py38"] target-version = ["py38"]
[tool.ruff] [tool.ruff]
line-length = 119
indent-width = 4
[tool.ruff.lint]
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"] ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
select = ["C", "E", "F", "I", "W"] select = ["C", "E", "F", "I", "W"]
line-length = 119
[tool.ruff.isort] [tool.ruff.lint.isort]
lines-after-imports = 2 lines-after-imports = 2
known-first-party = ["llmtuner"] known-first-party = ["llmtuner"]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
[isort] [isort]
default_section = "FIRSTPARTY" default_section = "FIRSTPARTY"
known_first_party = "llmtuner" known_first_party = "llmtuner"


@@ -1,8 +1,8 @@
torch>=1.13.1 torch>=1.13.1
transformers>=4.36.2 transformers>=4.37.2
datasets>=2.14.3 datasets>=2.14.3
accelerate>=0.21.0 accelerate>=0.21.0
peft>=0.7.0 peft>=0.8.2
trl>=0.7.6 trl>=0.7.6
gradio>=3.38.0,<4.0.0 gradio>=3.38.0,<4.0.0
scipy scipy


@@ -7,5 +7,5 @@ from .train import export_model, run_exp
from .webui import create_ui, create_web_demo from .webui import create_ui, create_web_demo
__version__ = "0.5.0" __version__ = "0.5.2"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"] __all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]


@@ -74,6 +74,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
) )
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1))) semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
role_mapping = {
Role.USER: DataRole.USER,
Role.ASSISTANT: DataRole.ASSISTANT,
Role.SYSTEM: DataRole.SYSTEM,
Role.FUNCTION: DataRole.FUNCTION,
Role.TOOL: DataRole.OBSERVATION,
}
@app.get("/v1/models", response_model=ModelList) @app.get("/v1/models", response_model=ModelList)
async def list_models(): async def list_models():
@@ -85,30 +92,29 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
if not chat_model.can_generate: if not chat_model.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if len(request.messages) == 0 or request.messages[-1].role != Role.USER: if len(request.messages) == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
messages = [dictify(message) for message in request.messages] if role_mapping[request.messages[0].role] == DataRole.SYSTEM:
if len(messages) and messages[0]["role"] == Role.SYSTEM: system = request.messages.pop(0).content
system = messages.pop(0)["content"]
else: else:
system = None system = ""
if len(messages) % 2 == 0: if len(request.messages) % 2 == 0:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
for i in range(len(messages)): input_messages = []
if i % 2 == 0 and messages[i]["role"] not in [Role.USER, Role.TOOL]: for i, message in enumerate(request.messages):
input_messages.append({"role": role_mapping[message.role], "content": message.content})
if i % 2 == 0 and input_messages[i]["role"] not in [DataRole.USER, DataRole.OBSERVATION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif i % 2 == 1 and messages[i]["role"] not in [Role.ASSISTANT, Role.FUNCTION]: elif i % 2 == 1 and input_messages[i]["role"] not in [DataRole.ASSISTANT, DataRole.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
elif messages[i]["role"] == Role.TOOL:
messages[i]["role"] = DataRole.OBSERVATION
tool_list = request.tools tool_list = request.tools
if len(tool_list): if isinstance(tool_list, list) and len(tool_list):
try: try:
tools = json.dumps([tool_list[0]["function"]], ensure_ascii=False) tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
except Exception: except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else: else:
@@ -116,10 +122,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
async with semaphore: async with semaphore:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, chat_completion, messages, system, tools, request) return await loop.run_in_executor(None, chat_completion, input_messages, system, tools, request)
def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest): def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
if request.stream: if request.stream:
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
generate = stream_chat_completion(messages, system, tools, request) generate = stream_chat_completion(messages, system, tools, request)
return EventSourceResponse(generate, media_type="text/event-stream") return EventSourceResponse(generate, media_type="text/event-stream")
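The reworked endpoint above maps OpenAI roles onto the internal data roles, serializes every tool in the request (not just the first), and rejects streaming when tools are present. A rough usage sketch with the `openai>=1.0` client follows; the base URL, port, model name, and the `get_weather` tool are assumptions for illustration, not values taken from the diff.

```python
from openai import OpenAI

# Rough sketch of exercising the updated /v1/chat/completions endpoint with
# tools. Base URL, port, model name, and the tool schema are assumptions.
client = OpenAI(api_key="none", base_url="http://localhost:8000/v1")

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string", "description": "city name"}
                },
                "required": ["city"],
            },
        },
    }
]

# Streaming is rejected when tools are present (see the added check above),
# so this request uses the non-streaming path.
response = client.chat.completions.create(
    model="llama-factory-model",
    messages=[{"role": "user", "content": "What is the weather in Berlin?"}],
    tools=tools,
)
print(response.choices[0].message)
```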


@@ -28,7 +28,7 @@ class ChatModel:
) )
self.tokenizer.padding_side = "left" if self.can_generate else "right" self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.model = dispatch_model(self.model) self.model = dispatch_model(self.model)
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer) self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
def _process_args( def _process_args(
self, self,
@@ -94,6 +94,9 @@ class ChatModel:
tools: Optional[str] = None, tools: Optional[str] = None,
**input_kwargs, **input_kwargs,
) -> List[Response]: ) -> List[Response]:
if not self.can_generate:
raise ValueError("The current model does not support `chat`.")
gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs) gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
generate_output = self.model.generate(**gen_kwargs) generate_output = self.model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:] response_ids = generate_output[:, prompt_length:]
@@ -123,6 +126,9 @@ class ChatModel:
tools: Optional[str] = None, tools: Optional[str] = None,
**input_kwargs, **input_kwargs,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
if not self.can_generate:
raise ValueError("The current model does not support `stream_chat`.")
gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs) gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer gen_kwargs["streamer"] = streamer
@@ -134,9 +140,11 @@ class ChatModel:
@torch.inference_mode() @torch.inference_mode()
def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]: def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
if self.can_generate:
raise ValueError("Cannot get scores using an auto-regressive model.")
max_length = input_kwargs.pop("max_length", None) max_length = input_kwargs.pop("max_length", None)
device = getattr(self.model.pretrained_model, "device", "cuda") device = getattr(self.model.pretrained_model, "device", "cuda")
inputs = self.tokenizer( inputs = self.tokenizer(
batch_input, batch_input,
padding=True, padding=True,


@@ -1,6 +1,8 @@
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features
from .utils import Role from .utils import Role
@@ -15,23 +17,24 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
outputs = {"prompt": [], "response": [], "system": [], "tools": []} outputs = {"prompt": [], "response": [], "system": [], "tools": []}
for i in range(len(examples[dataset_attr.prompt])): for i in range(len(examples[dataset_attr.prompt])):
prompt = [] prompt = []
if dataset_attr.history: if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
for old_prompt, old_response in examples[dataset_attr.history][i]: for old_prompt, old_response in examples[dataset_attr.history][i]:
prompt.append({"role": Role.USER, "content": old_prompt}) prompt.append({"role": Role.USER, "content": old_prompt})
prompt.append({"role": Role.ASSISTANT, "content": old_response}) prompt.append({"role": Role.ASSISTANT, "content": old_response})
instruction = examples[dataset_attr.prompt][i] content = []
if dataset_attr.query and examples[dataset_attr.query][i]: if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
instruction += "\n" + examples[dataset_attr.query][i] content.append(examples[dataset_attr.prompt][i])
prompt.append({"role": Role.USER, "content": instruction})
if dataset_attr.response: if dataset_attr.query and examples[dataset_attr.query][i]:
if isinstance(examples[dataset_attr.response][i], list): content.append(examples[dataset_attr.query][i])
response = [
{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i] prompt.append({"role": Role.USER, "content": "\n".join(content)})
]
else: if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}] response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
else: else:
response = [] response = []
@@ -50,32 +53,34 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
dataset_attr.assistant_tag: Role.ASSISTANT, dataset_attr.assistant_tag: Role.ASSISTANT,
dataset_attr.observation_tag: Role.OBSERVATION, dataset_attr.observation_tag: Role.OBSERVATION,
dataset_attr.function_tag: Role.FUNCTION, dataset_attr.function_tag: Role.FUNCTION,
dataset_attr.system_tag: Role.SYSTEM,
} }
odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
accept_tags = (odd_tags, even_tags)
for i, messages in enumerate(examples[dataset_attr.messages]): for i, messages in enumerate(examples[dataset_attr.messages]):
if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
system = messages[0][dataset_attr.content_tag]
messages = messages[1:]
else:
system = examples[dataset_attr.system][i] if dataset_attr.system else ""
messages = messages[: len(messages) // 2 * 2] # should be multiples of 2 messages = messages[: len(messages) // 2 * 2] # should be multiples of 2
if len(messages) == 0: if len(messages) == 0:
continue continue
prompt = [] aligned_messages = []
response = []
for turn_idx, message in enumerate(messages): for turn_idx, message in enumerate(messages):
if turn_idx % 2 == 0: if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag]
else:
accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag]
if message[dataset_attr.role_tag] not in accept_tags:
raise ValueError("Invalid role tag in {}.".format(messages)) raise ValueError("Invalid role tag in {}.".format(messages))
prompt.append( aligned_messages.append(
{"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]} {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
) )
last_message = prompt.pop(-1) outputs["prompt"].append(aligned_messages[:-1])
response.append(last_message) outputs["response"].append(aligned_messages[-1:])
outputs["prompt"].append(prompt) outputs["system"].append(system)
outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
return outputs return outputs
@@ -86,8 +91,8 @@ def align_dataset(
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r""" r"""
Aligned dataset: Aligned dataset:
prompt: [{"role": "user", "content": "..."}] prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..." system: "..."
tools: "..." tools: "..."
""" """
@@ -97,6 +102,18 @@ def align_dataset(
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr) convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
column_names = list(next(iter(dataset)).keys()) column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
{
"prompt": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"response": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
}
)
kwargs = {} kwargs = {}
if not data_args.streaming: if not data_args.streaming:
kwargs = dict( kwargs = dict(
@@ -105,4 +122,10 @@ def align_dataset(
desc="Converting format of dataset", desc="Converting format of dataset",
) )
return dataset.map(convert_func, batched=True, remove_columns=column_names, **kwargs) return dataset.map(
convert_func,
batched=True,
remove_columns=column_names,
features=features,
**kwargs,
)
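For orientation, here is a standalone sketch (not library code) of the aligned layout that `convert_sharegpt` produces after this change for a hypothetical two-turn sample: everything except the final assistant turn goes into `prompt`, the final turn goes into `response`. Role values are written as plain strings; the real code uses the `Role` constants.

```python
# Standalone illustration of the aligned record shape described in the updated
# docstring above, for a made-up two-turn sharegpt sample.
raw = {
    "conversations": [
        {"from": "human", "value": "Hi, who are you?"},
        {"from": "gpt", "value": "I am an assistant."},
        {"from": "human", "value": "Tell me a joke."},
        {"from": "gpt", "value": "Why did the chicken cross the road?"},
    ],
    "system": "You are a helpful assistant.",
}

# All messages except the last become the prompt (2T - 1 = 3 for T = 2 turns);
# the final assistant turn becomes the response.
aligned = {
    "prompt": [
        {"role": "user", "content": "Hi, who are you?"},
        {"role": "assistant", "content": "I am an assistant."},
        {"role": "user", "content": "Tell me a joke."},
    ],
    "response": [
        {"role": "assistant", "content": "Why did the chicken cross the road?"}
    ],
    # "system" comes from the system column (or the new system_tag message) when
    # configured; "tools" falls back to an empty string when no tools column is mapped.
    "system": "You are a helpful assistant.",
    "tools": "",
}
```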


@@ -15,11 +15,11 @@ JSON_FORMAT_PROMPT = (
TOOL_SYSTEM_PROMPT = ( TOOL_SYSTEM_PROMPT = (
"You have access to the following tools:\n{tool_text}" "You have access to the following tools:\n{tool_text}"
"Use the following format to answer the question:\n" "Use the following format if using a tool:\n"
"```\n"
"Action: tool name (one of [{tool_names}]).\n"
"Action Input: the input to the tool{format_prompt}.\n"
"```\n" "```\n"
"Action: the action to take, should be one of [{tool_names}] if using a tool.\n"
"Action Input: the input to the action{format_prompt}.\n"
"```"
) )
@@ -31,12 +31,16 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
for name, param in tool["parameters"]["properties"].items(): for name, param in tool["parameters"]["properties"].items():
required = ", required" if name in tool["parameters"].get("required", []) else "" required = ", required" if name in tool["parameters"].get("required", []) else ""
enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else "" enum = ", should be one of [{}]".format(", ".join(param["enum"])) if param.get("enum", None) else ""
param_text += " - {name} ({type}{required}): {desc}{enum}\n".format( items = (
", where each item should be {}".format(param["items"].get("type", "")) if param.get("items") else ""
)
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
name=name, name=name,
type=param.get("type", ""), type=param.get("type", ""),
required=required, required=required,
desc=param.get("description", ""), desc=param.get("description", ""),
enum=enum, enum=enum,
items=items,
) )
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format( tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
@@ -71,8 +75,7 @@ class Formatter(ABC):
tool_format: Literal["default"] = "default" tool_format: Literal["default"] = "default"
@abstractmethod @abstractmethod
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS: ...
...
def extract(self, content: str) -> Union[str, Tuple[str, str]]: def extract(self, content: str) -> Union[str, Tuple[str, str]]:
raise NotImplementedError raise NotImplementedError
@@ -91,12 +94,15 @@ class StringFormatter(Formatter):
for slot in self.slots: for slot in self.slots:
if isinstance(slot, str): if isinstance(slot, str):
for name, value in kwargs.items(): for name, value in kwargs.items():
if not isinstance(value, str):
raise RuntimeError("Expected a string, got {}".format(value))
slot = slot.replace("{{" + name + "}}", value, 1) slot = slot.replace("{{" + name + "}}", value, 1)
elements.append(slot) elements.append(slot)
elif isinstance(slot, (dict, set)): elif isinstance(slot, (dict, set)):
elements.append(slot) elements.append(slot)
else: else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements return elements
@@ -120,7 +126,7 @@ class FunctionFormatter(Formatter):
elif isinstance(slot, (dict, set)): elif isinstance(slot, (dict, set)):
elements.append(slot) elements.append(slot)
else: else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements return elements


@@ -30,6 +30,7 @@ def load_single_dataset(
model_args: "ModelArguments", model_args: "ModelArguments",
data_args: "DataArguments", data_args: "DataArguments",
): ):
logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]: if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
data_path = dataset_attr.dataset_name data_path = dataset_attr.dataset_name
@@ -60,7 +61,7 @@ def load_single_dataset(
if data_path is None: if data_path is None:
raise ValueError("File extension must be txt, csv, json or jsonl.") raise ValueError("File extension must be txt, csv, json or jsonl.")
checksum(data_files, dataset_attr.dataset_sha1) checksum(data_files, dataset_attr.file_sha1)
else: else:
raise NotImplementedError raise NotImplementedError
@@ -142,7 +143,7 @@ def get_dataset(
stage: Literal["pt", "sft", "rm", "ppo"], stage: Literal["pt", "sft", "rm", "ppo"],
# split: Optional[str] = "train", # TODO: add split # split: Optional[str] = "train", # TODO: add split
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
template = get_template_and_fix_tokenizer(data_args.template, tokenizer) template = get_template_and_fix_tokenizer(tokenizer, data_args.template)
if data_args.train_on_prompt and template.efficient_eos: if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.") raise ValueError("Current template does not support `train_on_prompt`.")
@@ -155,12 +156,9 @@ def get_dataset(
dataset = dataset.to_iterable_dataset() dataset = dataset.to_iterable_dataset()
return dataset return dataset
if data_args.streaming:
raise ValueError("Turn off dataset streaming to save cache files.")
with training_args.main_process_first(desc="load dataset"): with training_args.main_process_first(desc="load dataset"):
all_datasets = [] all_datasets = []
for dataset_attr in get_dataset_list(data_args): # TODO: add split for dataset_attr in get_dataset_list(data_args):
all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args)) all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
dataset = merge_dataset(all_datasets, data_args, training_args) dataset = merge_dataset(all_datasets, data_args, training_args)
@@ -188,6 +186,6 @@ def get_dataset(
try: try:
print_function(next(iter(dataset))) print_function(next(iter(dataset)))
except StopIteration: except StopIteration:
raise RuntimeError("Empty dataset!") raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")
return dataset return dataset


@@ -1,7 +1,7 @@
import json
import os
from dataclasses import dataclass
- from typing import TYPE_CHECKING, List, Literal, Optional
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope
@@ -13,34 +13,44 @@ if TYPE_CHECKING:
@dataclass
class DatasetAttr:
+ r"""
+ Dataset attributes.
+ """
+ """ basic configs """
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: Optional[str] = None
- dataset_sha1: Optional[str] = None
+ """ extra configs """
+ file_sha1: Optional[str] = None
subset: Optional[str] = None
folder: Optional[str] = None
ranking: Optional[bool] = False
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
+ """ columns """
system: Optional[str] = None
+ """ columns for the alpaca format """
prompt: Optional[str] = "instruction"
query: Optional[str] = "input"
response: Optional[str] = "output"
history: Optional[str] = None
+ """ columns for the sharegpt format """
messages: Optional[str] = "conversations"
tools: Optional[str] = None
+ """ tags for the sharegpt format """
role_tag: Optional[str] = "from"
content_tag: Optional[str] = "value"
user_tag: Optional[str] = "human"
assistant_tag: Optional[str] = "gpt"
observation_tag: Optional[str] = "observation"
function_tag: Optional[str] = "function_call"
+ system_tag: Optional[str] = "system"
def __repr__(self) -> str:
return self.dataset_name
+ def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
+ setattr(self, key, obj.get(key, default))
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
@@ -73,30 +83,36 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
elif "script_url" in dataset_info[name]:
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
else:
- dataset_attr = DatasetAttr(
- "file",
- dataset_name=dataset_info[name]["file_name"],
- dataset_sha1=dataset_info[name].get("file_sha1", None),
- )
+ dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
- dataset_attr.subset = dataset_info[name].get("subset", None)
- dataset_attr.folder = dataset_info[name].get("folder", None)
- dataset_attr.ranking = dataset_info[name].get("ranking", False)
- dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
+ dataset_attr.set_attr("file_sha1", dataset_info[name])
+ dataset_attr.set_attr("subset", dataset_info[name])
+ dataset_attr.set_attr("folder", dataset_info[name])
+ dataset_attr.set_attr("ranking", dataset_info[name], default=False)
+ dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
if "columns" in dataset_info[name]:
+ column_names = ["system"]
if dataset_attr.formatting == "alpaca":
- column_names = ["prompt", "query", "response", "history"]
+ column_names.extend(["prompt", "query", "response", "history"])
else:
- column_names = ["messages", "tools"]
+ column_names.extend(["messages", "tools"])
- column_names += ["system"]
for column_name in column_names:
- setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None))
+ dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
- for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag"]:
- setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None))
+ tag_names = (
+ "role_tag",
+ "content_tag",
+ "user_tag",
+ "assistant_tag",
+ "observation_tag",
+ "function_tag",
+ "system_tag",
+ )
+ for tag in tag_names:
+ dataset_attr.set_attr(tag, dataset_info[name]["tags"])
dataset_list.append(dataset_attr)
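The refactor above replaces the ad-hoc setattr calls with DatasetAttr.set_attr and adds a system_tag key for sharegpt-style data. A minimal sketch (not part of the diff) of how a dataset_info.json entry would now be mapped; the entry contents and the import path are illustrative assumptions:

from llmtuner.data.parser import DatasetAttr  # assumed module path

# hypothetical dataset_info.json entry for a sharegpt-formatted file
info = {
    "file_name": "my_data.json",
    "formatting": "sharegpt",
    "columns": {"messages": "conversations", "system": "system"},
    "tags": {"role_tag": "from", "content_tag": "value", "system_tag": "system"},
}

dataset_attr = DatasetAttr("file", dataset_name=info["file_name"])
dataset_attr.set_attr("file_sha1", info)  # missing keys simply fall back to the default
dataset_attr.set_attr("formatting", info, default="alpaca")
for tag in ("role_tag", "content_tag", "system_tag"):
    dataset_attr.set_attr(tag, info["tags"])
print(dataset_attr.formatting, dataset_attr.system_tag)  # sharegpt system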

View File

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Tuple
from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
+ from .utils import Role
if TYPE_CHECKING:
@@ -21,12 +22,8 @@ def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
- text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
+ text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
- for i in range(len(tokenized_examples["input_ids"])):
- tokenized_examples["input_ids"][i] += [tokenizer.eos_token_id]
- tokenized_examples["attention_mask"][i] += [1]
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
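The pre-training branch now appends tokenizer.eos_token to each document before tokenization instead of appending eos_token_id afterwards. A small sketch (not repo code) of the equivalence this relies on, assuming a tokenizer whose EOS string is kept as a single special token, e.g. GPT-2's:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice, not from the diff
text = "hello world"

# new style: append the EOS string before tokenizing
ids_a = tok(text + tok.eos_token, add_special_tokens=False)["input_ids"]
# old style: append the EOS id after tokenizing
ids_b = tok(text, add_special_tokens=False)["input_ids"] + [tok.eos_token_id]

print(ids_a == ids_b)  # expected True when the EOS string maps to exactly one id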
@@ -51,14 +48,19 @@ def preprocess_supervised_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for i in range(len(examples["prompt"])):
- if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
+ if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
messages = examples["prompt"][i] + examples["response"][i]
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
- tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+ tokenizer,
+ messages,
+ examples["system"][i],
+ examples["tools"][i],
+ data_args.cutoff_len,
+ data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
@@ -93,7 +95,7 @@ def preprocess_packed_supervised_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
input_ids, labels = [], []
for i in range(len(examples["prompt"])):
- if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
+ if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
continue
messages = examples["prompt"][i] + examples["response"][i]
@@ -137,12 +139,21 @@ def preprocess_unsupervised_dataset(
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
for i in range(len(examples["prompt"])):
- if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) != 1:
+ if len(examples["prompt"][i]) % 2 != 1:
continue
- messages = examples["prompt"][i] + examples["response"][i]
+ if len(examples["response"][i]) == 1:
+ messages = examples["prompt"][i] + examples["response"][i]
+ else:
+ messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]
input_ids, labels = template.encode_oneturn(
- tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+ tokenizer,
+ messages,
+ examples["system"][i],
+ examples["tools"][i],
+ data_args.cutoff_len,
+ data_args.reserved_label_len,
)
if template.efficient_eos:
@@ -164,17 +175,27 @@ def preprocess_pairwise_dataset(
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
for i in range(len(examples["prompt"])):
- if len(examples["prompt"][i]) == 0 or len(examples["response"][i]) < 2:
+ if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
continue
chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
- tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+ tokenizer,
+ chosen_messages,
+ examples["system"][i],
+ examples["tools"][i],
+ data_args.cutoff_len,
+ data_args.reserved_label_len,
)
_, rejected_ids = template.encode_oneturn(
- tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+ tokenizer,
+ rejected_messages,
+ examples["system"][i],
+ examples["tools"][i],
+ data_args.cutoff_len,
+ data_args.reserved_label_len,
)
if template.efficient_eos:
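The filters above change from rejecting only empty prompts to requiring an odd number of prompt messages, i.e. an alternating user/assistant history that ends on a user turn. A minimal illustration (plain data, not repo code):

prompt = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Tell me a joke."},
]
response = [{"role": "assistant", "content": "Why did the chicken cross the road?"}]

# mirrors the new supervised filter: odd-length prompt, exactly one response
keep = len(prompt) % 2 == 1 and len(response) == 1
print(keep)  # True; a prompt ending on an assistant turn (even length) is now skipped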

View File

@@ -37,7 +37,7 @@ class Template:
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000,
- reserved_label_len: Optional[int] = 16,
+ reserved_label_len: Optional[int] = 1,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
@@ -57,7 +57,7 @@ class Template:
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000,
- reserved_label_len: Optional[int] = 16,
+ reserved_label_len: Optional[int] = 1,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
@@ -117,9 +117,9 @@ class Template:
elif isinstance(elem, dict):
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
elif isinstance(elem, set):
- if "bos_token" in elem and tokenizer.bos_token_id:
+ if "bos_token" in elem and tokenizer.bos_token_id is not None:
token_ids += [tokenizer.bos_token_id]
- elif "eos_token" in elem and tokenizer.eos_token_id:
+ elif "eos_token" in elem and tokenizer.eos_token_id is not None:
token_ids += [tokenizer.eos_token_id]
else:
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
@@ -144,10 +144,10 @@ class Template:
max_len=(cutoff_len - total_length),
reserved_label_len=reserved_label_len,
)
- encoded_messages[i] = encoded_messages[i][:max_source_len]
- encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
- total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
- encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
+ source_ids = encoded_messages[i][:max_source_len]
+ target_ids = encoded_messages[i + 1][:max_target_len]
+ total_length += len(source_ids) + len(target_ids)
+ encoded_pairs.append((source_ids, target_ids))
return encoded_pairs
@@ -198,7 +198,7 @@ class Llama2Template(Template):
templates: Dict[str, Template] = {}
- def register_template(
+ def _register_template(
name: str,
format_user: Optional["Formatter"] = None,
format_assistant: Optional["Formatter"] = None,
@@ -218,7 +218,7 @@ def register_template(
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
- default_tool_formatter = ToolFormatter(slots="default")
+ default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
templates[name] = template_class(
format_user=format_user or default_user_formatter,
@@ -236,29 +236,45 @@ def register_template(
)
- def get_template_and_fix_tokenizer(name: str, tokenizer: "PreTrainedTokenizer") -> Template:
- if tokenizer.eos_token_id is None:
- tokenizer.eos_token = "<|endoftext|>"
+ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
+ is_added = tokenizer.eos_token_id is None
+ is_oov = eos_token not in tokenizer.get_vocab()
+ tokenizer.add_special_tokens({"eos_token": eos_token})
+ if is_added:
logger.info("Add eos token: {}".format(tokenizer.eos_token))
+ else:
+ logger.info("Replace eos token: {}".format(tokenizer.eos_token))
- if tokenizer.pad_token_id is None:
- tokenizer.pad_token = tokenizer.eos_token
- logger.info("Add pad token: {}".format(tokenizer.pad_token))
- if name is None: # for pre-training
- return None
- template = templates.get(name, None)
- assert template is not None, "Template {} does not exist.".format(name)
+ if is_oov:
+ logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+ def get_template_and_fix_tokenizer(
+ tokenizer: "PreTrainedTokenizer",
+ name: Optional[str] = None,
+ ) -> Template:
+ if name is None:
+ template = templates["vanilla"]  # placeholder
+ else:
+ template = templates.get(name, None)
+ if template is None:
+ raise ValueError("Template {} does not exist.".format(name))
stop_words = template.stop_words
if template.replace_eos:
if not stop_words:
raise ValueError("Stop words are required to replace the EOS token.")
- tokenizer.eos_token = stop_words[0]
+ _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
stop_words = stop_words[1:]
- logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+ if tokenizer.eos_token_id is None:
+ _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
+ if tokenizer.pad_token_id is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ logger.info("Add pad token: {}".format(tokenizer.pad_token))
if stop_words:
tokenizer.add_special_tokens(
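Since the argument order is reversed and the template name becomes optional, call sites change accordingly. A hedged usage sketch (the import path and model name are assumptions, not part of the diff):

from transformers import AutoTokenizer
from llmtuner.data.template import get_template_and_fix_tokenizer  # assumed module path

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative model
template = get_template_and_fix_tokenizer(tokenizer, "llama2")  # tokenizer first, template name second
plain = get_template_and_fix_tokenizer(tokenizer)  # name=None now returns the "vanilla" placeholder template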
@@ -269,7 +285,7 @@ def get_template_and_fix_tokenizer(name: str, tokenizer: "PreTrainedTokenizer")
return template
- register_template(
+ _register_template(
name="alpaca",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
@@ -279,7 +295,7 @@ register_template(
)
- register_template(
+ _register_template(
name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
format_separator=EmptyFormatter(slots=["###"]),
@@ -292,21 +308,21 @@ register_template(
)
- register_template(
+ _register_template(
name="baichuan",
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True,
)
- register_template(
+ _register_template(
name="baichuan2",
format_user=StringFormatter(slots=[{"token": "<reserved_106>"}, "{{content}}", {"token": "<reserved_107>"}]),
efficient_eos=True,
)
- register_template(
+ _register_template(
name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
@@ -315,13 +331,13 @@ register_template(
)
- register_template(
+ _register_template(
name="bluelm",
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)
- register_template(
+ _register_template(
name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
@@ -331,7 +347,7 @@ register_template(
)
- register_template(
+ _register_template(
name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
@@ -339,7 +355,9 @@ register_template(
slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
),
format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
- format_observation=StringFormatter(slots=[{"token": "<|observation|>"}, "\n", "{{content}}"]),
+ format_observation=StringFormatter(
+ slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
+ ),
default_system=(
"You are ChatGLM3, a large language model trained by Zhipu.AI. "
"Follow the user's instructions carefully. Respond using markdown."
@@ -349,14 +367,33 @@ register_template(
)
- register_template(
+ _register_template(
+ name="chatml_de",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_separator=EmptyFormatter(slots=["\n"]),
+ default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
+ stop_words=["<|im_end|>"],
+ replace_eos=True,
+ )
+ _register_template(
name="codegeex2",
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
force_system=True,
)
- register_template(
+ _register_template(
+ name="cpm",
+ format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
+ format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+ force_system=True,
+ )
+ _register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
@@ -364,9 +401,10 @@ register_template(
)
- register_template(
+ _register_template(
name="deepseekcoder",
- format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:\n"]),
+ format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
+ format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_separator=EmptyFormatter(slots=["\n", {"token": "<|EOT|>"}, "\n"]),
default_system=(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
@@ -379,14 +417,15 @@ register_template(
)
- register_template(
+ _register_template(
name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
+ format_system=StringFormatter(slots=["{{content}}\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
)
- register_template(
+ _register_template(
name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
format_separator=EmptyFormatter(slots=["\n"]),
@@ -394,7 +433,7 @@ register_template(
)
- register_template(
+ _register_template(
name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
format_separator=EmptyFormatter(slots=[{"token": "<eoa>"}, "\n"]),
@@ -403,7 +442,7 @@ register_template(
)
- register_template(
+ _register_template(
name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=[{"bos_token"}, "<|im_start|>system\n{{content}}<|im_end|>\n"]),
@@ -420,7 +459,7 @@ register_template(
)
- register_template(
+ _register_template(
name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
@@ -437,7 +476,7 @@ register_template(
)
- register_template(
+ _register_template(
name="llama2_zh",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
@@ -445,7 +484,7 @@ register_template(
)
- register_template(
+ _register_template(
name="mistral",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
@@ -453,7 +492,7 @@ register_template(
)
- register_template(
+ _register_template(
name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
@@ -462,7 +501,15 @@ register_template(
)
- register_template(
+ _register_template(
+ name="orion",
+ format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
+ format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+ force_system=True,
+ )
+ _register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
@@ -473,7 +520,7 @@ register_template(
)
- register_template(
+ _register_template(
name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
@@ -481,7 +528,7 @@ register_template(
)
- register_template(
+ _register_template(
name="starchat",
format_user=StringFormatter(
slots=[{"token": "<|user|>"}, "\n{{content}}", {"token": "<|end|>"}, "\n", {"token": "<|assistant|>"}]
@@ -494,10 +541,12 @@ register_template(
)
- register_template(name="vanilla")
+ _register_template(
+ name="vanilla",
+ )
- register_template(
+ _register_template(
name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
default_system=(
@@ -507,7 +556,7 @@ register_template(
)
- register_template(
+ _register_template(
name="xuanyuan",
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
default_system=(
@@ -518,10 +567,13 @@ register_template(
)
- register_template(name="xverse", format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]))
+ _register_template(
+ name="xverse",
+ format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
+ )
- register_template(
+ _register_template(
name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
@@ -541,7 +593,7 @@ register_template(
)
- register_template(
+ _register_template(
name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_separator=EmptyFormatter(slots=["\n"]),
@@ -550,7 +602,7 @@ register_template(
)
- register_template(
+ _register_template(
name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
format_separator=EmptyFormatter(slots=["\n"]),
@@ -559,7 +611,7 @@ register_template(
)
- register_template(
+ _register_template(
name="zephyr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
@@ -567,7 +619,7 @@ register_template(
)
- register_template(
+ _register_template(
name="ziya",
format_user=StringFormatter(slots=[{"token": "<human>"}, ":{{content}}\n", {"token": "<bot>"}, ":"]),
format_separator=EmptyFormatter(slots=["\n"]),

View File

@@ -19,8 +19,9 @@ logger = get_logger(__name__)
class Role(str, Enum):
USER = "user"
ASSISTANT = "assistant"
- OBSERVATION = "observation"
+ SYSTEM = "system"
FUNCTION = "function"
+ OBSERVATION = "observation"
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:

View File

@@ -24,7 +24,7 @@ class Evaluator:
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
self.model = dispatch_model(self.model)
- self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
+ self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
self.eval_template = get_eval_template(self.eval_args.lang)
self.choice_inputs = [
self.tokenizer.encode(self.eval_template.prefix + ch, add_special_tokens=False)[-1] for ch in CHOICES

View File

@@ -11,7 +11,14 @@ DEFAULT_MODULE = defaultdict(str)
DEFAULT_TEMPLATE = defaultdict(str)
- FILEEXT2TYPE = {"arrow": "arrow", "csv": "csv", "json": "json", "jsonl": "json", "parquet": "parquet", "txt": "text"}
+ FILEEXT2TYPE = {
+ "arrow": "arrow",
+ "csv": "csv",
+ "json": "json",
+ "jsonl": "json",
+ "parquet": "parquet",
+ "txt": "text",
+ }
IGNORE_INDEX = -100
@@ -46,7 +53,9 @@ class DownloadSource(str, Enum):
def register_model_group(
- models: Dict[str, Dict[DownloadSource, str]], module: Optional[str] = None, template: Optional[str] = None
+ models: Dict[str, Dict[DownloadSource, str]],
+ module: Optional[str] = None,
+ template: Optional[str] = None,
) -> None:
prefix = None
for name, path in models.items():
@@ -219,22 +228,36 @@ register_model_group(
register_model_group(
models={
- "DeepSeekLLM-7B-Base": {
+ "DeepSeek-LLM-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
},
- "DeepSeekLLM-67B-Base": {
+ "DeepSeek-LLM-67B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
},
- "DeepSeekLLM-7B-Chat": {
+ "DeepSeek-LLM-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
},
- "DeepSeekLLM-67B-Chat": {
+ "DeepSeek-LLM-67B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
},
"DeepSeek-Math-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
},
"DeepSeek-Math-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
},
"DeepSeek-MoE-16B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
},
"DeepSeek-MoE-16B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
},
},
template="deepseek",
)
@@ -246,6 +269,9 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
}, },
"DeepSeekCoder-7B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
},
"DeepSeekCoder-33B-Base": { "DeepSeekCoder-33B-Base": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
@@ -254,6 +280,9 @@ register_model_group(
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
}, },
"DeepSeekCoder-7B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
},
"DeepSeekCoder-33B-Chat": { "DeepSeekCoder-33B-Chat": {
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct", DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct", DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
@@ -263,21 +292,6 @@ register_model_group(
)
- register_model_group(
- models={
- "DeepSeekMoE-16B-Base": {
- DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
- DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
- },
- "DeepSeekMoE-16B-Chat": {
- DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
- DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
- },
- },
- template="deepseek",
- )
register_model_group(
models={
"Falcon-7B": {
@@ -370,7 +384,10 @@ register_model_group(
register_model_group(
models={
- "LLaMA-7B": {DownloadSource.DEFAULT: "huggyllama/llama-7b", DownloadSource.MODELSCOPE: "skyline2006/llama-7b"},
+ "LLaMA-7B": {
+ DownloadSource.DEFAULT: "huggyllama/llama-7b",
+ DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
+ },
"LLaMA-13B": {
DownloadSource.DEFAULT: "huggyllama/llama-13b",
DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
@@ -455,7 +472,7 @@ register_model_group(
register_model_group(
models={
"OpenChat3.5-7B-Chat": {
- DownloadSource.DEFAULT: "openchat/openchat_3.5",
+ DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5",
}
},
@@ -465,18 +482,63 @@ register_model_group(
register_model_group(
models={
- "Phi-1.5-1.3B": {DownloadSource.DEFAULT: "microsoft/phi-1_5", DownloadSource.MODELSCOPE: "allspace/PHI_1-5"},
- "Phi-2-2.7B": {DownloadSource.DEFAULT: "microsoft/phi-2", DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2"},
+ "Orion-14B-Base": {
+ DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
},
"Orion-14B-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
},
"Orion-14B-Long-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
},
"Orion-14B-RAG-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
},
"Orion-14B-Plugin-Chat": {
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
},
},
template="orion",
)
register_model_group(
models={
"Phi-1.5-1.3B": {
DownloadSource.DEFAULT: "microsoft/phi-1_5",
DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
},
"Phi-2-2.7B": {
DownloadSource.DEFAULT: "microsoft/phi-2",
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
},
}
)
register_model_group(
models={
- "Qwen-1.8B": {DownloadSource.DEFAULT: "Qwen/Qwen-1_8B", DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"},
- "Qwen-7B": {DownloadSource.DEFAULT: "Qwen/Qwen-7B", DownloadSource.MODELSCOPE: "qwen/Qwen-7B"},
- "Qwen-14B": {DownloadSource.DEFAULT: "Qwen/Qwen-14B", DownloadSource.MODELSCOPE: "qwen/Qwen-14B"},
- "Qwen-72B": {DownloadSource.DEFAULT: "Qwen/Qwen-72B", DownloadSource.MODELSCOPE: "qwen/Qwen-72B"},
+ "Qwen-1.8B": {
+ DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
+ DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B",
+ },
"Qwen-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen-7B",
},
"Qwen-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen-14B",
},
"Qwen-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen-72B",
},
"Qwen-1.8B-Chat": { "Qwen-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat", DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
@@ -530,7 +592,112 @@ register_model_group(
register_model_group(
models={
- "SOLAR-10.7B": {DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0"},
+ "Qwen1.5-0.5B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B",
},
"Qwen1.5-1.8B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B",
},
"Qwen1.5-4B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B",
},
"Qwen1.5-7B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B",
},
"Qwen1.5-14B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
},
"Qwen1.5-72B": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
},
"Qwen1.5-0.5B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
},
"Qwen1.5-1.8B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat",
},
"Qwen1.5-4B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat",
},
"Qwen1.5-7B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat",
},
"Qwen1.5-14B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
},
"Qwen1.5-72B-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
},
"Qwen1.5-0.5B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
},
"Qwen1.5-0.5B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4",
},
"Qwen1.5-1.8B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
},
"Qwen1.5-1.8B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4",
},
"Qwen1.5-4B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
},
"Qwen1.5-4B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int4",
},
"Qwen1.5-7B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
},
"Qwen1.5-7B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int4",
},
"Qwen1.5-14B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
},
"Qwen1.5-14B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int4",
},
"Qwen1.5-72B-int8-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
},
"Qwen1.5-72B-int4-Chat": {
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int4",
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int4",
},
},
template="qwen",
)
register_model_group(
models={
"SOLAR-10.7B": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
},
"SOLAR-10.7B-Chat": { "SOLAR-10.7B-Chat": {
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0", DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0", DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
@@ -567,10 +734,18 @@ register_model_group(
register_model_group(
models={
- "XuanYuan-70B": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"},
- "XuanYuan-70B-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"},
- "XuanYuan-70B-int8-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"},
- "XuanYuan-70B-int4-Chat": {DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"},
+ "XuanYuan-70B": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
+ },
+ "XuanYuan-70B-Chat": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
+ },
+ "XuanYuan-70B-int8-Chat": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
+ },
+ "XuanYuan-70B-int4-Chat": {
+ DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
+ },
},
template="xuanyuan",
)
@@ -578,9 +753,18 @@ register_model_group(
register_model_group(
models={
- "XVERSE-7B": {DownloadSource.DEFAULT: "xverse/XVERSE-7B", DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"},
- "XVERSE-13B": {DownloadSource.DEFAULT: "xverse/XVERSE-13B", DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"},
- "XVERSE-65B": {DownloadSource.DEFAULT: "xverse/XVERSE-65B", DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"},
+ "XVERSE-7B": {
+ DownloadSource.DEFAULT: "xverse/XVERSE-7B",
+ DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
},
"XVERSE-13B": {
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
},
"XVERSE-65B": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
},
"XVERSE-65B-2": { "XVERSE-65B-2": {
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2", DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2", DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
@@ -619,10 +803,22 @@ register_model_group(
register_model_group(
models={
- "Yi-6B": {DownloadSource.DEFAULT: "01-ai/Yi-6B", DownloadSource.MODELSCOPE: "01ai/Yi-6B"},
- "Yi-34B": {DownloadSource.DEFAULT: "01-ai/Yi-34B", DownloadSource.MODELSCOPE: "01ai/Yi-34B"},
- "Yi-6B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"},
- "Yi-34B-Chat": {DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"},
+ "Yi-6B": {
+ DownloadSource.DEFAULT: "01-ai/Yi-6B",
+ DownloadSource.MODELSCOPE: "01ai/Yi-6B",
+ },
"Yi-34B": {
DownloadSource.DEFAULT: "01-ai/Yi-34B",
DownloadSource.MODELSCOPE: "01ai/Yi-34B",
},
"Yi-6B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat",
},
"Yi-34B-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
},
"Yi-6B-int8-Chat": { "Yi-6B-int8-Chat": {
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits", DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits", DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",

View File

@@ -10,6 +10,7 @@ from transformers.utils import (
WEIGHTS_NAME,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
+ is_torch_mps_available,
is_torch_npu_available,
is_torch_xpu_available,
)
@@ -133,6 +134,8 @@ def get_current_device() -> torch.device:
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0")) device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_npu_available(): elif is_torch_npu_available():
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0")) device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_mps_available():
device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
elif is_torch_cuda_available(): elif is_torch_cuda_available():
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0")) device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
else: else:
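With the new branch, Apple Silicon machines without CUDA fall through to MPS. A small usage sketch (the import path is an assumption):

import os
from llmtuner.extras.misc import get_current_device  # assumed import path

os.environ.setdefault("LOCAL_RANK", "0")
print(get_current_device())  # e.g. mps:0 on macOS, cuda:0 on a CUDA machine, cpu otherwise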

View File

@@ -2,11 +2,11 @@ import importlib.metadata
import importlib.util
- def is_package_available(name: str) -> bool:
+ def _is_package_available(name: str) -> bool:
return importlib.util.find_spec(name) is not None
- def get_package_version(name: str) -> str:
+ def _get_package_version(name: str) -> str:
try:
return importlib.metadata.version(name)
except Exception:
@@ -14,36 +14,40 @@ def get_package_version(name: str) -> str:
def is_fastapi_availble():
- return is_package_available("fastapi")
+ return _is_package_available("fastapi")
def is_flash_attn2_available():
- return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
+ return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")
def is_jieba_available():
- return is_package_available("jieba")
+ return _is_package_available("jieba")
def is_matplotlib_available():
- return is_package_available("matplotlib")
+ return _is_package_available("matplotlib")
def is_nltk_available():
- return is_package_available("nltk")
+ return _is_package_available("nltk")
def is_requests_available():
- return is_package_available("requests")
+ return _is_package_available("requests")
def is_rouge_available():
- return is_package_available("rouge_chinese")
+ return _is_package_available("rouge_chinese")
def is_starlette_available():
- return is_package_available("sse_starlette")
+ return _is_package_available("sse_starlette")
+ def is_unsloth_available():
+ return _is_package_available("unsloth")
def is_uvicorn_available():
- return is_package_available("uvicorn")
+ return _is_package_available("uvicorn")
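The underscore prefix keeps the raw helpers module-private, while the public is_*_available() checks keep gating optional imports; the diff also adds a check for unsloth. A hedged sketch of the intended usage pattern (the import path is an assumption):

from llmtuner.extras.packages import is_unsloth_available  # assumed import path

if is_unsloth_available():
    from unsloth import FastLanguageModel  # imported only when the optional package is installed
else:
    FastLanguageModel = None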

View File

@@ -0,0 +1,38 @@
import torch
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP, MixtralSparseMoeBlock
def mlp_forward(self: "MixtralBLockSparseTop2MLP", hidden_states: torch.Tensor) -> torch.Tensor:
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
# Modified from: https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
def moe_forward(self: "MixtralSparseMoeBlock", hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
topk_weight = topk_weight.to(hidden_states.dtype)
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states)
flat_topk_idx = topk_idx.view(-1)
for i in range(self.num_experts):
expert = self.experts[i]
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
def patch_mixtral_replace_moe_impl() -> None:
MixtralBLockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward
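The new module monkey-patches the Mixtral MoE forward passes with a dense loop over experts, adapted (as noted in the file) from the DeepSeek-MoE modeling code. A hedged sketch of how the patch would be applied; the module path and model name are assumptions, not part of the diff:

from transformers import MixtralForCausalLM
from llmtuner.extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl  # assumed path

# swap the class-level forward implementations in place, then load the model as usual
patch_mixtral_replace_moe_impl()
model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")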

View File

@@ -7,31 +7,42 @@ class DataArguments:
r""" r"""
Arguments pertaining to what data we are going to input our model for training and evaluation. Arguments pertaining to what data we are going to input our model for training and evaluation.
""" """
template: Optional[str] = field( template: Optional[str] = field(
default=None, metadata={"help": "Which template to use for constructing prompts in training and inference."} default=None,
metadata={"help": "Which template to use for constructing prompts in training and inference."},
) )
dataset: Optional[str] = field( dataset: Optional[str] = field(
default=None, default=None,
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}, metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
) )
dataset_dir: Optional[str] = field( dataset_dir: Optional[str] = field(
default="data", metadata={"help": "Path to the folder containing the datasets."} default="data",
metadata={"help": "Path to the folder containing the datasets."},
) )
split: Optional[str] = field( split: Optional[str] = field(
default="train", metadata={"help": "Which dataset split to use for training and evaluation."} default="train",
metadata={"help": "Which dataset split to use for training and evaluation."},
) )
    cutoff_len: Optional[int] = field(
        default=1024,
        metadata={"help": "The cutoff length of the model inputs after tokenization."},
    )
    reserved_label_len: Optional[int] = field(
        default=1,
        metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
    )
    train_on_prompt: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to disable the mask on the prompt or not."},
    )
    streaming: Optional[bool] = field(
        default=False,
        metadata={"help": "Enable dataset streaming."},
    )
    buffer_size: Optional[int] = field(
        default=16384,
        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
    )
    mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
        default="concat",
@@ -42,13 +53,16 @@ class DataArguments:
        metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
    )
    overwrite_cache: Optional[bool] = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
    )
    eval_num_beams: Optional[int] = field(
        default=None,
@@ -57,17 +71,20 @@ class DataArguments:
    ignore_pad_token_for_loss: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
        },
    )
    val_size: Optional[float] = field(
        default=0,
        metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
    )
    sft_packing: Optional[bool] = field(
        default=False,
        metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."},
    )
    cache_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save or load the preprocessed datasets."},
    )

    def __post_init__(self):
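Note: the `streaming` and `buffer_size` options map onto the Hugging Face `datasets` streaming API. A minimal sketch of what they control is shown below; the dataset name is a placeholder, not something used by this repository.

# Sketch only: approximate shuffling over an iterable (streamed) dataset keeps a
# buffer of `buffer_size` examples and samples from it, as the help text describes.
from datasets import load_dataset

stream = load_dataset("some/dataset", split="train", streaming=True)  # placeholder name
shuffled = stream.shuffle(seed=42, buffer_size=16384)
for example in shuffled.take(3):
    print(example)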

View File

@@ -10,15 +10,34 @@ class EvaluationArguments:
r""" r"""
Arguments pertaining to specify the evaluation parameters. Arguments pertaining to specify the evaluation parameters.
""" """
task: str = field(metadata={"help": "Name of the evaluation task."})
task_dir: Optional[str] = field( task: str = field(
default="evaluation", metadata={"help": "Path to the folder containing the evaluation datasets."} metadata={"help": "Name of the evaluation task."},
)
task_dir: Optional[str] = field(
default="evaluation",
metadata={"help": "Path to the folder containing the evaluation datasets."},
)
batch_size: Optional[int] = field(
default=4,
metadata={"help": "The batch size per GPU for evaluation."},
)
seed: Optional[int] = field(
default=42,
metadata={"help": "Random seed to be used with data loaders."},
)
lang: Optional[Literal["en", "zh"]] = field(
default="en",
metadata={"help": "Language used at evaluation."},
)
n_shot: Optional[int] = field(
default=5,
metadata={"help": "Number of examplars for few-shot learning."},
)
save_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to save the evaluation results."},
) )
batch_size: Optional[int] = field(default=4, metadata={"help": "The batch size per GPU for evaluation."})
seed: Optional[int] = field(default=42, metadata={"help": "Random seed to be used with data loaders."})
lang: Optional[Literal["en", "zh"]] = field(default="en", metadata={"help": "Language used at evaluation."})
n_shot: Optional[int] = field(default=5, metadata={"help": "Number of examplars for few-shot learning."})
save_dir: Optional[str] = field(default=None, metadata={"help": "Path to save the evaluation results."})
download_mode: Optional[DownloadMode] = field( download_mode: Optional[DownloadMode] = field(
default=DownloadMode.REUSE_DATASET_IF_EXISTS, default=DownloadMode.REUSE_DATASET_IF_EXISTS,
metadata={"help": "Download mode used for the evaluation datasets."}, metadata={"help": "Download mode used for the evaluation datasets."},

View File

@@ -8,20 +8,27 @@ class FreezeArguments:
r""" r"""
Arguments pertaining to the freeze (partial-parameter) training. Arguments pertaining to the freeze (partial-parameter) training.
""" """
name_module_trainable: Optional[str] = field( name_module_trainable: Optional[str] = field(
default="mlp", default=None,
metadata={ metadata={
"help": 'Name of trainable modules for partial-parameter (freeze) fine-tuning. \ "help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
Use commas to separate multiple modules. \ Use commas to separate multiple modules. \
LLaMA choices: ["mlp", "self_attn"], \ Use "all" to specify all the available modules. \
BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \ LLaMA choices: ["mlp", "self_attn"], \
Qwen choices: ["mlp", "attn"], \ BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
Phi choices: ["mlp", "mixer"], \ Qwen choices: ["mlp", "attn"], \
Others choices: the same as LLaMA.' InternLM2 choices: ["feed_forward", "attention"], \
Others choices: the same as LLaMA."""
}, },
) )
num_layer_trainable: Optional[int] = field( num_layer_trainable: Optional[int] = field(
default=3, metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."} default=3,
metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
)
use_llama_pro: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to use llama pro for partial-parameter (freeze) fine-tuning."},
) )
@@ -30,6 +37,7 @@ class LoraArguments:
r""" r"""
Arguments pertaining to the LoRA training. Arguments pertaining to the LoRA training.
""" """
additional_target: Optional[str] = field( additional_target: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
@@ -37,27 +45,42 @@ class LoraArguments:
}, },
) )
lora_alpha: Optional[int] = field( lora_alpha: Optional[int] = field(
default=None, metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."} default=None,
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
)
lora_dropout: Optional[float] = field(
default=0.0,
metadata={"help": "Dropout rate for the LoRA fine-tuning."},
)
lora_rank: Optional[int] = field(
default=8,
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
) )
lora_dropout: Optional[float] = field(default=0.0, metadata={"help": "Dropout rate for the LoRA fine-tuning."})
lora_rank: Optional[int] = field(default=8, metadata={"help": "The intrinsic dimension for LoRA fine-tuning."})
lora_target: Optional[str] = field( lora_target: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
"help": 'Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ "help": """Name(s) of target modules to apply LoRA. \
LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \ Use commas to separate multiple modules. \
BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \ Use "all" to specify all the available modules. \
Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \ LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \ BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
Phi choices: ["Wqkv", "out_proj", "fc1", "fc2"], \ Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
Others choices: the same as LLaMA.' Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
Others choices: the same as LLaMA."""
}, },
) )
lora_bf16_mode: Optional[bool] = field( lora_bf16_mode: Optional[bool] = field(
default=False, metadata={"help": "Whether or not to train lora adapters in bf16 precision."} default=False,
metadata={"help": "Whether or not to train lora adapters in bf16 precision."},
)
use_rslora: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
) )
create_new_adapter: Optional[bool] = field( create_new_adapter: Optional[bool] = field(
default=False, metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."} default=False,
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
) )
@@ -66,49 +89,66 @@ class RLHFArguments:
r""" r"""
Arguments pertaining to the PPO and DPO training. Arguments pertaining to the PPO and DPO training.
""" """
dpo_beta: Optional[float] = field(default=0.1, metadata={"help": "The beta parameter for the DPO loss."})
dpo_beta: Optional[float] = field(
default=0.1,
metadata={"help": "The beta parameter for the DPO loss."},
)
dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field( dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
default="sigmoid", metadata={"help": "The type of DPO loss to use."} default="sigmoid",
metadata={"help": "The type of DPO loss to use."},
) )
dpo_ftx: Optional[float] = field( dpo_ftx: Optional[float] = field(
default=0, metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."} default=0,
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
) )
ppo_buffer_size: Optional[int] = field( ppo_buffer_size: Optional[int] = field(
default=1, default=1,
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}, metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
) )
ppo_epochs: Optional[int] = field( ppo_epochs: Optional[int] = field(
default=4, metadata={"help": "The number of epochs to perform in a PPO optimization step."} default=4,
metadata={"help": "The number of epochs to perform in a PPO optimization step."},
) )
ppo_logger: Optional[str] = field( ppo_logger: Optional[str] = field(
default=None, metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'} default=None,
metadata={"help": 'Log with either "wandb" or "tensorboard" in PPO training.'},
) )
ppo_score_norm: Optional[bool] = field( ppo_score_norm: Optional[bool] = field(
default=False, metadata={"help": "Use score normalization in PPO training."} default=False,
metadata={"help": "Use score normalization in PPO training."},
) )
ppo_target: Optional[float] = field( ppo_target: Optional[float] = field(
default=6.0, metadata={"help": "Target KL value for adaptive KL control in PPO training."} default=6.0,
metadata={"help": "Target KL value for adaptive KL control in PPO training."},
) )
ppo_whiten_rewards: Optional[bool] = field( ppo_whiten_rewards: Optional[bool] = field(
default=False, metadata={"help": "Whiten the rewards before compute advantages in PPO training."} default=False,
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
) )
ref_model: Optional[str] = field( ref_model: Optional[str] = field(
default=None, metadata={"help": "Path to the reference model used for the PPO or DPO training."} default=None,
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
) )
ref_model_adapters: Optional[str] = field( ref_model_adapters: Optional[str] = field(
default=None, metadata={"help": "Path to the adapters of the reference model."} default=None,
metadata={"help": "Path to the adapters of the reference model."},
) )
ref_model_quantization_bit: Optional[int] = field( ref_model_quantization_bit: Optional[int] = field(
default=None, metadata={"help": "The number of bits to quantize the reference model."} default=None,
metadata={"help": "The number of bits to quantize the reference model."},
) )
reward_model: Optional[str] = field( reward_model: Optional[str] = field(
default=None, metadata={"help": "Path to the reward model used for the PPO training."} default=None,
metadata={"help": "Path to the reward model used for the PPO training."},
) )
reward_model_adapters: Optional[str] = field( reward_model_adapters: Optional[str] = field(
default=None, metadata={"help": "Path to the adapters of the reward model."} default=None,
metadata={"help": "Path to the adapters of the reward model."},
) )
reward_model_quantization_bit: Optional[int] = field( reward_model_quantization_bit: Optional[int] = field(
default=None, metadata={"help": "The number of bits to quantize the reward model."} default=None,
metadata={"help": "The number of bits to quantize the reward model."},
) )
reward_model_type: Optional[Literal["lora", "full", "api"]] = field( reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
default="lora", default="lora",
@@ -121,14 +161,22 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
r""" r"""
Arguments pertaining to which techniques we are going to fine-tuning with. Arguments pertaining to which techniques we are going to fine-tuning with.
""" """
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft", metadata={"help": "Which stage will be performed in training."} default="sft",
metadata={"help": "Which stage will be performed in training."},
) )
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field( finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
default="lora", metadata={"help": "Which fine-tuning method to use."} default="lora",
metadata={"help": "Which fine-tuning method to use."},
)
disable_version_checking: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to disable version checking."},
) )
plot_loss: Optional[bool] = field( plot_loss: Optional[bool] = field(
default=False, metadata={"help": "Whether or not to save the training loss curves."} default=False,
metadata={"help": "Whether or not to save the training loss curves."},
) )
def __post_init__(self): def __post_init__(self):
@@ -152,6 +200,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora": if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.") raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
if self.use_llama_pro and self.finetuning_type != "freeze":
raise ValueError("`use_llama_pro` is only valid for the Freeze method.")
def save_to_json(self, json_path: str): def save_to_json(self, json_path: str):
r"""Saves the content of this instance in JSON format inside `json_path`.""" r"""Saves the content of this instance in JSON format inside `json_path`."""
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
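Note: the new `use_rslora` flag corresponds to PEFT's rank-stabilized LoRA, which scales adapter updates by lora_alpha / sqrt(r) instead of lora_alpha / r. A hedged sketch of the resulting PEFT config follows; the rank, alpha, and target modules are illustrative values, not project defaults.

# Illustrative only: how the LoRA flags above could translate into a PEFT LoraConfig.
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                       # lora_rank
    lora_alpha=16,             # defaults to lora_rank * 2 when left unset above
    lora_dropout=0.0,
    target_modules=["q_proj", "v_proj"],
    use_rslora=True,           # scale updates by lora_alpha / sqrt(r) instead of lora_alpha / r
)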

View File

@@ -7,11 +7,14 @@ class GeneratingArguments:
r""" r"""
Arguments pertaining to specify the decoding parameters. Arguments pertaining to specify the decoding parameters.
""" """
do_sample: Optional[bool] = field( do_sample: Optional[bool] = field(
default=True, metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."} default=True,
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."},
) )
temperature: Optional[float] = field( temperature: Optional[float] = field(
default=0.95, metadata={"help": "The value used to modulate the next token probabilities."} default=0.95,
metadata={"help": "The value used to modulate the next token probabilities."},
) )
top_p: Optional[float] = field( top_p: Optional[float] = field(
default=0.7, default=0.7,
@@ -24,7 +27,8 @@ class GeneratingArguments:
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}, metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."},
) )
num_beams: Optional[int] = field( num_beams: Optional[int] = field(
default=1, metadata={"help": "Number of beams for beam search. 1 means no beam search."} default=1,
metadata={"help": "Number of beams for beam search. 1 means no beam search."},
) )
max_length: Optional[int] = field( max_length: Optional[int] = field(
default=512, default=512,
@@ -35,10 +39,12 @@ class GeneratingArguments:
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}, metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
) )
repetition_penalty: Optional[float] = field( repetition_penalty: Optional[float] = field(
default=1.0, metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."} default=1.0,
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."},
) )
length_penalty: Optional[float] = field( length_penalty: Optional[float] = field(
default=1.0, metadata={"help": "Exponential penalty to the length that is used with beam-based generation."} default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
) )
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:

View File

@@ -7,11 +7,15 @@ class ModelArguments:
r""" r"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune. Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
""" """
model_name_or_path: str = field( model_name_or_path: str = field(
metadata={"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."} metadata={
"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
},
) )
adapter_name_or_path: Optional[str] = field( adapter_name_or_path: Optional[str] = field(
default=None, metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."} default=None,
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."},
) )
cache_dir: Optional[str] = field( cache_dir: Optional[str] = field(
default=None, default=None,
@@ -22,7 +26,8 @@ class ModelArguments:
metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."}, metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
) )
resize_vocab: Optional[bool] = field( resize_vocab: Optional[bool] = field(
default=False, metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."} default=False,
metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
) )
split_special_tokens: Optional[bool] = field( split_special_tokens: Optional[bool] = field(
default=False, default=False,
@@ -33,60 +38,88 @@ class ModelArguments:
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
) )
quantization_bit: Optional[int] = field( quantization_bit: Optional[int] = field(
default=None, metadata={"help": "The number of bits to quantize the model."} default=None,
metadata={"help": "The number of bits to quantize the model."},
) )
quantization_type: Optional[Literal["fp4", "nf4"]] = field( quantization_type: Optional[Literal["fp4", "nf4"]] = field(
default="nf4", metadata={"help": "Quantization data type to use in int4 training."} default="nf4",
metadata={"help": "Quantization data type to use in int4 training."},
) )
double_quantization: Optional[bool] = field( double_quantization: Optional[bool] = field(
default=True, metadata={"help": "Whether or not to use double quantization in int4 training."} default=True,
metadata={"help": "Whether or not to use double quantization in int4 training."},
) )
rope_scaling: Optional[Literal["linear", "dynamic"]] = field( rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
default=None, metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."} default=None,
metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
) )
flash_attn: Optional[bool] = field( flash_attn: Optional[bool] = field(
default=False, metadata={"help": "Enable FlashAttention-2 for faster training."} default=False,
metadata={"help": "Enable FlashAttention-2 for faster training."},
) )
shift_attn: Optional[bool] = field( shift_attn: Optional[bool] = field(
default=False, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} default=False,
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
) )
use_unsloth: Optional[bool] = field( use_unsloth: Optional[bool] = field(
default=False, metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."} default=False,
metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
) )
disable_gradient_checkpointing: Optional[bool] = field( disable_gradient_checkpointing: Optional[bool] = field(
default=False, metadata={"help": "Whether or not to disable gradient checkpointing."} default=False,
metadata={"help": "Whether or not to disable gradient checkpointing."},
) )
upcast_layernorm: Optional[bool] = field( upcast_layernorm: Optional[bool] = field(
default=False, metadata={"help": "Whether or not to upcast the layernorm weights in fp32."} default=False,
metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
) )
upcast_lmhead_output: Optional[bool] = field( upcast_lmhead_output: Optional[bool] = field(
default=False, metadata={"help": "Whether or not to upcast the output of lm_head in fp32."} default=False,
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
)
hf_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."},
)
ms_hub_token: Optional[str] = field(
default=None,
metadata={"help": "Auth token to log in with ModelScope Hub."},
) )
hf_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."})
ms_hub_token: Optional[str] = field(default=None, metadata={"help": "Auth token to log in with ModelScope Hub."})
export_dir: Optional[str] = field( export_dir: Optional[str] = field(
default=None, metadata={"help": "Path to the directory to save the exported model."} default=None,
metadata={"help": "Path to the directory to save the exported model."},
) )
export_size: Optional[int] = field( export_size: Optional[int] = field(
default=1, metadata={"help": "The file shard size (in GB) of the exported model."} default=1,
metadata={"help": "The file shard size (in GB) of the exported model."},
) )
export_quantization_bit: Optional[int] = field( export_quantization_bit: Optional[int] = field(
default=None, metadata={"help": "The number of bits to quantize the exported model."} default=None,
metadata={"help": "The number of bits to quantize the exported model."},
) )
export_quantization_dataset: Optional[str] = field( export_quantization_dataset: Optional[str] = field(
default=None, metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."} default=None,
metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
) )
export_quantization_nsamples: Optional[int] = field( export_quantization_nsamples: Optional[int] = field(
default=128, metadata={"help": "The number of samples used for quantization."} default=128,
metadata={"help": "The number of samples used for quantization."},
) )
export_quantization_maxlen: Optional[int] = field( export_quantization_maxlen: Optional[int] = field(
default=1024, metadata={"help": "The maximum length of the model inputs used for quantization."} default=1024,
metadata={"help": "The maximum length of the model inputs used for quantization."},
) )
export_legacy_format: Optional[bool] = field( export_legacy_format: Optional[bool] = field(
default=False, metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."} default=False,
metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
) )
export_hub_model_id: Optional[str] = field( export_hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository if push the model to the Hugging Face hub."} default=None,
metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
)
print_param_status: Optional[bool] = field(
default=False,
metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
) )
def __post_init__(self): def __post_init__(self):

View File

@@ -8,8 +8,10 @@ import torch
import transformers
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils.versions import require_version

from ..extras.logging import get_logger
+from ..extras.packages import is_unsloth_available
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@@ -28,6 +30,17 @@ _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArgu
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


+def _check_dependencies(disabled: bool) -> None:
+    if disabled:
+        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
+    else:
+        require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
+        require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
+        require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
+        require_version("peft>=0.8.2", "To fix: pip install peft>=0.8.2")
+        require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
+
+
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
    if args is not None:
        return parser.parse_dict(args)
@@ -60,16 +73,15 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
        if finetuning_args.finetuning_type != "lora":
            raise ValueError("Quantization is only compatible with the LoRA method.")

-       if finetuning_args.create_new_adapter:
+       if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
            raise ValueError("Cannot create new adapter upon a quantized model.")

    if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
+       if finetuning_args.finetuning_type != "lora":
+           raise ValueError("Multiple adapters are only available for LoRA tuning.")
+
+       if model_args.quantization_bit is not None:
            raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

-   if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
-       raise ValueError("Adapter is only valid for the LoRA method.")


def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    parser = HfArgumentParser(_TRAIN_ARGS)
@@ -121,10 +133,29 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    if training_args.do_train and training_args.predict_with_generate:
        raise ValueError("`predict_with_generate` cannot be set as True while training.")

+   if (
+       training_args.do_train
+       and finetuning_args.finetuning_type == "freeze"
+       and finetuning_args.name_module_trainable is None
+   ):
+       raise ValueError("Please specify `name_module_trainable` in Freeze training.")
+
    if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
        raise ValueError("Please specify `lora_target` in LoRA training.")

+   if training_args.do_train and model_args.use_unsloth and not is_unsloth_available:
+       raise ValueError("Install Unsloth: https://github.com/unslothai/unsloth")
+
    _verify_model_args(model_args, finetuning_args)
+   _check_dependencies(disabled=finetuning_args.disable_version_checking)
+
+   if (
+       training_args.do_train
+       and finetuning_args.finetuning_type == "lora"
+       and model_args.resize_vocab
+       and finetuning_args.additional_target is None
+   ):
+       logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
+
    if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
        logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
@@ -138,7 +169,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
        logger.warning("Specify `ref_model` for computing rewards at evaluation.")

-   # postprocess training_args
+   # Post-process training arguments
    if (
        training_args.local_rank != -1
        and training_args.ddp_find_unused_parameters is None
@@ -151,7 +182,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
        can_resume_from_checkpoint = False
-       training_args.resume_from_checkpoint = None
+       if training_args.resume_from_checkpoint is not None:
+           logger.warning("Cannot resume from checkpoint in current stage.")
+           training_args.resume_from_checkpoint = None
    else:
        can_resume_from_checkpoint = True
@@ -187,7 +220,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
        )
    )

-   # postprocess model_args
+   # Post-process model arguments
    model_args.compute_dtype = (
        torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
    )
@@ -205,7 +238,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
    )
    logger.info(f"Training/evaluation parameters {training_args}")

-   # Set seed before initializing model.
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args, generating_args
@@ -213,25 +245,27 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
    model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

    _set_transformers_logging()
+   _verify_model_args(model_args, finetuning_args)
+   _check_dependencies(disabled=finetuning_args.disable_version_checking)

    if data_args.template is None:
        raise ValueError("Please specify which `template` to use.")

-   _verify_model_args(model_args, finetuning_args)

    return model_args, data_args, finetuning_args, generating_args


def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
    model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

    _set_transformers_logging()
+   _verify_model_args(model_args, finetuning_args)
+   _check_dependencies(disabled=finetuning_args.disable_version_checking)

    if data_args.template is None:
        raise ValueError("Please specify which `template` to use.")

-   _verify_model_args(model_args, finetuning_args)

    transformers.set_seed(eval_args.seed)

    return model_args, data_args, eval_args, finetuning_args
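Note: a rough usage sketch of the parser entry point, assuming it is importable from the package's hparams module. Arguments can be passed as a plain dict (handled by parse_dict()); the paths below are placeholders, and disable_version_checking skips the require_version() calls added above.

from llmtuner.hparams.parser import get_train_args  # assumed import path

model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(
    dict(
        stage="sft",
        model_name_or_path="path/to/base-model",   # placeholder
        dataset="alpaca_gpt4_en",
        template="default",
        finetuning_type="lora",
        lora_target="q_proj,v_proj",
        output_dir="path/to/output",                # placeholder
        do_train=True,
        disable_version_checking=True,              # skip the dependency version checks
    )
)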

View File

@@ -1,5 +1,5 @@
from .loader import load_model_and_tokenizer
-from .utils import dispatch_model, get_modelcard_args, load_valuehead_params
+from .utils import dispatch_model, load_valuehead_params

-__all__ = ["load_model_and_tokenizer", "dispatch_model", "get_modelcard_args", "load_valuehead_params"]
+__all__ = ["load_model_and_tokenizer", "dispatch_model", "load_valuehead_params"]

View File

@@ -1,8 +1,7 @@
-import inspect
from typing import TYPE_CHECKING

import torch
-from peft import LoraConfig, PeftModel, TaskType, get_peft_model
+from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled

from ..extras.logging import get_logger
@@ -47,21 +46,41 @@ def init_adapter(
        if not num_layers:
            raise ValueError("Current model does not support freeze tuning.")

-       if finetuning_args.num_layer_trainable > 0:  # fine-tuning the last n layers if num_layer_trainable > 0
-           trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
-       else:  # fine-tuning the first n layers if num_layer_trainable < 0
-           trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]  # noqa: C416
+       if finetuning_args.use_llama_pro:
+           if num_layers % finetuning_args.num_layer_trainable != 0:
+               raise ValueError(
+                   "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(
+                       num_layers, finetuning_args.num_layer_trainable
+                   )
+               )
+
+           stride = num_layers // finetuning_args.num_layer_trainable
+           trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
+       elif finetuning_args.num_layer_trainable > 0:  # fine-tuning the last n layers if num_layer_trainable > 0
+           trainable_layer_ids = range(num_layers - finetuning_args.num_layer_trainable, num_layers)
+       else:  # fine-tuning the first n layers if num_layer_trainable < 0
+           trainable_layer_ids = range(-finetuning_args.num_layer_trainable)
+
+       freeze_modules = {"all"}
+       for name, _ in model.named_modules():
+           if ".0." in name:
+               freeze_modules.add(name.split(".0.")[-1].split(".")[0])

        trainable_layers = []
        for module_name in finetuning_args.name_module_trainable:
+           if module_name not in freeze_modules:
+               raise ValueError(
+                   "Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules))
+               )
+
            for idx in trainable_layer_ids:
-               trainable_layers.append("{:d}.{}".format(idx, module_name))
+               trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))

        for name, param in model.named_parameters():
-           if not any(trainable_layer in name for trainable_layer in trainable_layers):
-               param.requires_grad_(False)
-           else:
+           if any(trainable_layer in name for trainable_layer in trainable_layers):
                param.data = param.data.to(torch.float32)
+           else:
+               param.requires_grad_(False)

    if finetuning_args.finetuning_type == "lora":
        logger.info("Fine-tuning method: LoRA")
@@ -84,7 +103,7 @@ def init_adapter(
            adapter_to_merge = model_args.adapter_name_or_path

        for adapter in adapter_to_merge:
-           model = PeftModel.from_pretrained(model, adapter)
+           model: "LoraModel" = PeftModel.from_pretrained(model, adapter)
            model = model.merge_and_unload()

        if len(adapter_to_merge) > 0:
@@ -104,22 +123,14 @@ def init_adapter(
                "target_modules": target_modules,
                "lora_alpha": finetuning_args.lora_alpha,
                "lora_dropout": finetuning_args.lora_dropout,
+               "use_rslora": finetuning_args.use_rslora,
            }

            if model_args.use_unsloth:
-               from unsloth import FastLlamaModel, FastMistralModel  # type: ignore
+               from unsloth import FastLanguageModel  # type: ignore

                unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
-               if "loftq_config" in inspect.signature(FastLlamaModel.get_peft_model).parameters:
-                   unsloth_peft_kwargs["loftq_config"] = {}
-
-               if getattr(model.config, "model_type", None) == "llama":
-                   model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
-               elif getattr(model.config, "model_type", None) == "mistral":
-                   model = FastMistralModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
-               else:
-                   raise NotImplementedError
+               model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
            else:
                lora_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM,
@@ -132,7 +143,7 @@ def init_adapter(
            for param in filter(lambda p: p.requires_grad, model.parameters()):
                param.data = param.data.to(torch.bfloat16 if finetuning_args.lora_bf16_mode else torch.float32)

        if model_args.adapter_name_or_path is not None:
            logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

        return model
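Note: a worked example of the LLaMA Pro layer selection introduced above, using illustrative values rather than anything configured by this repository. With 32 layers and num_layer_trainable=8, the stride is 32 // 8 = 4, so every fourth layer (counting from 0) is left trainable.

# Standalone check of the stride computation shown in the diff above.
num_layers, num_layer_trainable = 32, 8
stride = num_layers // num_layer_trainable
print(list(range(stride - 1, num_layers + stride - 1, stride)))
# [3, 7, 11, 15, 19, 23, 27, 31]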

View File

@@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Optional, Tuple
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead

from ..extras.logging import get_logger
@@ -21,13 +20,6 @@ if TYPE_CHECKING:
logger = get_logger(__name__)


-require_version("transformers>=4.36.2", "To fix: pip install transformers>=4.36.2")
-require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
-require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
-require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
-require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")


def load_model_and_tokenizer(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
@@ -63,8 +55,7 @@ def load_model_and_tokenizer(
    model = None
    if is_trainable and model_args.use_unsloth:
-       require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
-       from unsloth import FastLlamaModel, FastMistralModel  # type: ignore
+       from unsloth import FastLanguageModel  # type: ignore

        unsloth_kwargs = {
            "model_name": model_args.model_name_or_path,
@@ -72,14 +63,12 @@ def load_model_and_tokenizer(
            "dtype": model_args.compute_dtype,
            "load_in_4bit": model_args.quantization_bit == 4,
            "token": model_args.hf_hub_token,
-           "device_map": get_current_device(),
+           "device_map": {"": get_current_device()},
            "rope_scaling": getattr(config, "rope_scaling", None),
        }
-       if getattr(config, "model_type", None) == "llama":
-           model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
-       elif getattr(config, "model_type", None) == "mistral":
-           model, _ = FastMistralModel.from_pretrained(**unsloth_kwargs)
-       else:
+       try:
+           model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
+       except NotImplementedError:
            logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
            model_args.use_unsloth = False
@@ -132,4 +121,12 @@ def load_model_and_tokenizer(
    if not is_trainable:
        logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")

+   if model_args.print_param_status:
+       for name, param in model.named_parameters():
+           print(
+               "name: {}, dtype: {}, device: {}, trainable: {}".format(
+                   name, param.dtype, param.device, param.requires_grad
+               )
+           )
+
    return model, tokenizer

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch
from datasets import load_dataset
+from peft import PeftModel
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version
@@ -16,6 +17,7 @@ from ..extras.logging import get_logger
from ..extras.misc import get_current_device, infer_optim_dtype
from ..extras.packages import is_flash_attn2_available
from ..extras.patches.llama_patch import apply_llama_patch
+from ..extras.patches.mixtral_patch import patch_mixtral_replace_moe_impl


if TYPE_CHECKING:
@@ -100,6 +102,18 @@ def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "Mod
    return samples


+def _configure_attn_implementation(model_args: "ModelArguments", config_kwargs: Dict[str, Any]) -> None:
+    if model_args.flash_attn:
+        if is_flash_attn2_available():
+            config_kwargs["attn_implementation"] = "flash_attention_2"
+            logger.info("Using FlashAttention-2 for faster training and inference.")
+        else:
+            logger.warning("FlashAttention2 is not installed.")
+            config_kwargs["attn_implementation"] = None
+    else:
+        config_kwargs["attn_implementation"] = "eager"
+
+
def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if not hasattr(config, "rope_scaling"):
        logger.warning("Current model does not support RoPE scaling.")
@@ -127,15 +141,6 @@ def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is
        )


-def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
-    if not is_flash_attn2_available():
-        logger.warning("FlashAttention2 is not installed.")
-        return
-
-    config_kwargs["use_flash_attention_2"] = True
-    logger.info("Using FlashAttention-2 for faster training and inference.")


def _configure_longlora(config: "PretrainedConfig") -> None:
    if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
        setattr(config, "group_size_ratio", 0.25)
@@ -222,7 +227,10 @@ def _prepare_model_for_training(
        if not getattr(model, "supports_gradient_checkpointing", False):
            logger.warning("Current model does not support gradient checkpointing.")
        else:
-           model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+           # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
+           # According to: https://github.com/huggingface/transformers/issues/28339
+           model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
+           model.enable_input_require_grads()
            model.config.use_cache = False  # turn off when gradient checkpointing is enabled
            logger.info("Gradient checkpointing enabled.")
@@ -255,12 +263,11 @@ def patch_config(
    for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
        setattr(config, dtype_name, model_args.compute_dtype == dtype)

+   _configure_attn_implementation(model_args, config_kwargs)
+
    if model_args.rope_scaling is not None:
        _configure_rope(config, model_args, is_trainable)

-   if model_args.flash_attn:
-       _configure_flashattn(config_kwargs)
-
    if is_trainable and model_args.shift_attn:
        _configure_longlora(config)
@@ -283,6 +290,21 @@ def patch_model(
    if is_trainable:
        _prepare_model_for_training(model, model_args)

+   if getattr(model.config, "model_type", None) == "mixtral" and is_deepspeed_zero3_enabled():
+       require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
+       from deepspeed.utils import set_z3_leaf_modules  # type: ignore
+       from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+       set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+
+       if is_trainable:
+           patch_mixtral_replace_moe_impl()
+
+   try:
+       model.add_model_tags(["llama-factory"])
+   except Exception:
+       logger.warning("Cannot properly tag the model.")
+

def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
    def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
@@ -293,7 +315,12 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
        if isinstance(self.pretrained_model, PreTrainedModel):
            return self.pretrained_model.get_input_embeddings()

+   def create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
+       if isinstance(self.pretrained_model, PeftModel):
+           self.pretrained_model.create_or_update_model_card(output_dir)
+
    ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
    setattr(model, "_keys_to_ignore_on_save", ignore_modules)
    setattr(model, "tie_weights", MethodType(tie_weights, model))
    setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
+   setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model))

View File

@@ -1,5 +1,5 @@
import inspect
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Dict, List

import torch
from transformers import PreTrainedModel
@@ -13,7 +13,7 @@ from ..extras.misc import get_current_device

if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedTokenizer

-   from ..hparams import DataArguments, FinetuningArguments, ModelArguments
+   from ..hparams import ModelArguments


logger = get_logger(__name__)
@@ -41,7 +41,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
        # Make sure tied weights are tied before creating the device map.
        model.tie_weights()
        device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
-       device_map_kwargs = {"device_map": device_map}
+       device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"}
        if "skip_keys" in inspect.signature(dispatch_model).parameters:
            device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
        return dispatch_model(model, **device_map_kwargs)
@@ -76,18 +76,6 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
    return list(module_names)


-def get_modelcard_args(
-    model_args: "ModelArguments", data_args: "DataArguments", finetuning_args: "FinetuningArguments"
-) -> Dict[str, Any]:
-    return {
-        "tasks": "text-generation",
-        "license": "other",
-        "finetuned_from": model_args.model_name_or_path,
-        "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
-        "tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else []),
-    }


def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
    r"""
    Loads value head parameters from Hugging Face Hub or local disk.

View File

@@ -1,7 +1,9 @@
import json
+from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional

import torch
+from transformers.integrations import is_deepspeed_zero3_enabled

from ...extras.packages import is_requests_available
@@ -23,18 +25,22 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
-   if target == "reward":  # save default head temporarily
-       valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
-       setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
-       setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
-
-   model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
-   model.v_head.load_state_dict(
-       {
-           "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
-           "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone(),
-       }
-   )
+   if is_deepspeed_zero3_enabled():
+       import deepspeed  # type: ignore
+
+       params = [model.v_head.summary.weight, model.v_head.summary.bias]
+       context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
+   else:
+       context_maybe_zero3 = nullcontext()
+
+   with context_maybe_zero3:
+       if target == "reward":  # save default head temporarily
+           setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
+           setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
+
+       model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
+       model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
+       model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()


def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
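Note: under DeepSpeed ZeRO-3 the value-head weights are partitioned across ranks, so they have to be gathered before they can be read or swapped, which is what the GatheredParameters context above does. A minimal, self-contained sketch of the same pattern (it assumes a ZeRO-3 run is already initialized; the Linear layer is just a stand-in for the value head):

import deepspeed  # type: ignore
import torch

layer = torch.nn.Linear(4, 1)  # stand-in for the value head
with deepspeed.zero.GatheredParameters([layer.weight, layer.bias], modifier_rank=0):
    # Inside this context the partitioned parameters are materialized on every rank;
    # modifications are broadcast from modifier_rank when the context exits.
    print(layer.weight.shape)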

View File

@@ -56,12 +56,13 @@ def export_model(args: Optional[Dict[str, Any]] = None):
    if not isinstance(model, PreTrainedModel):
        raise ValueError("The model is not a `PreTrainedModel`, export aborted.")

-   setattr(model.config, "use_cache", True)
-   if getattr(model.config, "torch_dtype", None) == "bfloat16":
-       model = model.to(torch.bfloat16).to("cpu")
+   if getattr(model, "quantization_method", None):
+       model = model.to("cpu")
+   elif hasattr(model.config, "torch_dtype"):
+       model = model.to(getattr(model.config, "torch_dtype")).to("cpu")
    else:
        model = model.to(torch.float16).to("cpu")
-       setattr(model.config, "torch_dtype", "float16")
+       setattr(model.config, "torch_dtype", torch.float16)

    model.save_pretrained(
        save_directory=model_args.export_dir,

View File

@@ -4,7 +4,7 @@ import torch
from ..extras.logging import get_logger
from ..hparams import FinetuningArguments, ModelArguments
-from ..model import get_modelcard_args, load_model_and_tokenizer, load_valuehead_params
+from ..model import load_model_and_tokenizer, load_valuehead_params


if TYPE_CHECKING:
@@ -25,14 +25,18 @@ def create_modelcard_and_push(
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> None:
-   if training_args.do_train:
-       if training_args.push_to_hub:
-           trainer.push_to_hub(**get_modelcard_args(model_args, data_args, finetuning_args))
-           return
-       try:
-           trainer.create_model_card(**get_modelcard_args(model_args, data_args, finetuning_args))
-       except Exception as err:
-           logger.warning("Failed to create model card: {}".format(str(err)))
+   kwargs = {
+       "tasks": "text-generation",
+       "finetuned_from": model_args.model_name_or_path,
+       "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
+       "tags": ["llama-factory", finetuning_args.finetuning_type],
+   }
+   if not training_args.do_train:
+       pass
+   elif training_args.push_to_hub:
+       trainer.push_to_hub(**kwargs)
+   else:
+       trainer.create_model_card(license="other", **kwargs)  # prevent from connecting to hub


def create_ref_model(

View File

@@ -26,6 +26,7 @@ def save_model(
    max_shard_size: int,
    export_quantization_bit: int,
    export_quantization_dataset: str,
+   export_legacy_format: bool,
    export_dir: str,
) -> Generator[str, None, None]:
    error = ""
@@ -61,6 +62,7 @@ def save_model(
        export_size=max_shard_size,
        export_quantization_bit=int(export_quantization_bit) if export_quantization_bit in GPTQ_BITS else None,
        export_quantization_dataset=export_quantization_dataset,
+       export_legacy_format=export_legacy_format,
    )

    yield ALERTS["info_exporting"][lang]
@@ -73,6 +75,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
    max_shard_size = gr.Slider(value=1, minimum=1, maximum=100)
    export_quantization_bit = gr.Dropdown(choices=["none", "8", "4", "3", "2"], value="none")
    export_quantization_dataset = gr.Textbox(value="data/c4_demo.json")
+   export_legacy_format = gr.Checkbox()
    export_dir = gr.Textbox()
    export_btn = gr.Button()
@@ -90,6 +93,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
            max_shard_size,
            export_quantization_bit,
            export_quantization_dataset,
+           export_legacy_format,
            export_dir,
        ],
        [info_box],
@@ -99,6 +103,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
        max_shard_size=max_shard_size,
        export_quantization_bit=export_quantization_bit,
        export_quantization_dataset=export_quantization_dataset,
+       export_legacy_format=export_legacy_format,
        export_dir=export_dir,
        export_btn=export_btn,
        info_box=info_box,

View File

@@ -16,7 +16,7 @@ def create_top() -> Dict[str, "Component"]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]

     with gr.Row():
-        lang = gr.Dropdown(choices=["en", "zh"], scale=1)
+        lang = gr.Dropdown(choices=["en", "ru", "zh"], scale=1)
         model_name = gr.Dropdown(choices=available_models, scale=3)
         model_path = gr.Textbox(scale=3)
@@ -30,7 +30,7 @@ def create_top() -> Dict[str, "Component"]:
         quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
         template = gr.Dropdown(choices=list(templates.keys()), value="default")
         rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
-        booster = gr.Radio(choices=["none", "flash_attn", "unsloth"], value="none")
+        booster = gr.Radio(choices=["none", "flashattn", "unsloth"], value="none")

     model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
         get_model_path, [model_name], [model_path], queue=False

View File

@@ -52,8 +52,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         )

     with gr.Row():
-        batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
-        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
+        batch_size = gr.Slider(value=4, minimum=1, maximum=1024, step=1)
+        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=1024, step=1)
         lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
         max_grad_norm = gr.Textbox(value="1.0")
         val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
@@ -76,11 +76,24 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
             neftune_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)

-        with gr.Column():
-            sft_packing = gr.Checkbox(value=False)
-            upcast_layernorm = gr.Checkbox(value=False)
-    input_elems.update({logging_steps, save_steps, warmup_steps, neftune_alpha, sft_packing, upcast_layernorm})
+        with gr.Row():
+            resize_vocab = gr.Checkbox()
+            sft_packing = gr.Checkbox()
+            upcast_layernorm = gr.Checkbox()
+            use_llama_pro = gr.Checkbox()
+    input_elems.update(
+        {
+            logging_steps,
+            save_steps,
+            warmup_steps,
+            neftune_alpha,
+            resize_vocab,
+            sft_packing,
+            upcast_layernorm,
+            use_llama_pro,
+        }
+    )
     elem_dict.update(
         dict(
             extra_tab=extra_tab,
@@ -88,20 +101,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             save_steps=save_steps,
             warmup_steps=warmup_steps,
             neftune_alpha=neftune_alpha,
+            resize_vocab=resize_vocab,
             sft_packing=sft_packing,
             upcast_layernorm=upcast_layernorm,
+            use_llama_pro=use_llama_pro,
         )
     )

     with gr.Accordion(label="LoRA config", open=False) as lora_tab:
         with gr.Row():
-            lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
-            lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
-            lora_target = gr.Textbox(scale=1)
-            additional_target = gr.Textbox(scale=1)
-            create_new_adapter = gr.Checkbox(scale=1)
-    input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, create_new_adapter})
+            lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
+            lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
+            lora_target = gr.Textbox()
+            additional_target = gr.Textbox()
+            with gr.Column():
+                use_rslora = gr.Checkbox()
+                create_new_adapter = gr.Checkbox()
+    input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, use_rslora, create_new_adapter})
     elem_dict.update(
         dict(
             lora_tab=lora_tab,
@@ -109,6 +127,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             lora_dropout=lora_dropout,
             lora_target=lora_target,
             additional_target=additional_target,
+            use_rslora=use_rslora,
             create_new_adapter=create_new_adapter,
         )
     )
@@ -143,7 +162,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         output_dir = gr.Textbox()

     with gr.Row():
-        resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
+        resume_btn = gr.Checkbox(visible=False, interactive=False)
         process_bar = gr.Slider(visible=False, interactive=False)

     with gr.Box():

File diff suppressed because it is too large

View File

@@ -125,12 +125,15 @@ class Runner:
             save_steps=get("train.save_steps"),
             warmup_steps=get("train.warmup_steps"),
             neftune_noise_alpha=get("train.neftune_alpha") or None,
+            resize_vocab=get("train.resize_vocab"),
             sft_packing=get("train.sft_packing"),
             upcast_layernorm=get("train.upcast_layernorm"),
+            use_llama_pro=get("train.use_llama_pro"),
             lora_rank=get("train.lora_rank"),
             lora_dropout=get("train.lora_dropout"),
             lora_target=get("train.lora_target") or get_module(get("top.model_name")),
             additional_target=get("train.additional_target") or None,
+            use_rslora=get("train.use_rslora"),
             create_new_adapter=get("train.create_new_adapter"),
             output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),

View File

@@ -10,7 +10,7 @@ import fire
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import DataCollatorForSeq2Seq
+from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq

 from llmtuner.data import get_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
@@ -24,26 +24,35 @@ BASE_BS = 4_000_000 # from llama paper
 def calculate_lr(
     model_name_or_path: str,
-    dataset: str,
-    cutoff_len: int,  # i.e. maximum input length during training
     batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
-    is_mistral: bool,  # mistral model uses a smaller learning rate,
+    stage: Optional[str] = "sft",
+    dataset: Optional[str] = "alpaca_en",
     dataset_dir: Optional[str] = "data",
+    template: Optional[str] = "default",
+    cutoff_len: Optional[int] = 1024,  # i.e. maximum input length during training
+    is_mistral: Optional[bool] = False,  # mistral model uses a smaller learning rate,
 ):
     model_args, data_args, training_args, finetuning_args, _ = get_train_args(
         dict(
-            stage="sft",
+            stage=stage,
             model_name_or_path=model_name_or_path,
             dataset=dataset,
             dataset_dir=dataset_dir,
-            template="default",
+            template=template,
             cutoff_len=cutoff_len,
             output_dir="dummy_dir",
+            overwrite_cache=True,
         )
     )
     _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
-    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
-    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
+    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage=stage)
+    if stage == "pt":
+        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+    elif stage == "sft":
+        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
+    else:
+        raise NotImplementedError
     dataloader = DataLoader(
         dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True
     )
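This script estimates a learning rate from the number of valid tokens processed per optimization step, relative to the 4M-token reference batch noted in the BASE_BS comment above. As a rough illustration of that kind of batch-size-based scaling (the square-root rule and the 3e-4 reference value below are assumptions for the sketch, not taken from the diff):

import math

BASE_LR = 3e-4       # assumed reference learning rate for the 4M-token reference batch
BASE_BS = 4_000_000  # reference tokens per batch, from the llama paper (as above)

def scaled_lr(tokens_per_step: int) -> float:
    """Scale the reference LR by the square root of the relative batch size (illustrative rule)."""
    return BASE_LR * math.sqrt(tokens_per_step / BASE_BS)

# e.g. roughly 512 sequences of ~1024 valid tokens per optimization step
print("suggested lr: {:.2e}".format(scaled_lr(512 * 1024)))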

52
tests/length_cdf.py Normal file
View File

@@ -0,0 +1,52 @@
# coding=utf-8
# Calculates the distribution of the input lengths in the dataset.
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default

from collections import defaultdict
from typing import Optional

import fire
from tqdm import tqdm

from llmtuner.data import get_dataset
from llmtuner.hparams import get_train_args
from llmtuner.model import load_model_and_tokenizer


def length_cdf(
    model_name_or_path: str,
    dataset: Optional[str] = "alpaca_en",
    dataset_dir: Optional[str] = "data",
    template: Optional[str] = "default",
    interval: Optional[int] = 1000,
):
    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
        dict(
            stage="sft",
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=1_000_000,
            output_dir="dummy_dir",
            overwrite_cache=True,
        )
    )
    _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
    total_num = len(trainset)
    length_dict = defaultdict(int)
    for sample in tqdm(trainset["input_ids"]):
        length_dict[len(sample) // interval * interval] += 1

    length_tuples = list(length_dict.items())
    length_tuples.sort()
    count_accu, prob_accu = 0, 0
    for length, count in length_tuples:
        count_accu += count
        prob_accu += count / total_num * 100
        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))


if __name__ == "__main__":
    fire.Fire(length_cdf)
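This is handy for picking a cutoff_len: each printed line gives the cumulative share of samples that fit under a length bucket. A direct call is equivalent to the CLI usage in the header comment; the model path below is only an example.

# Equivalent to the CLI usage above; reports the cumulative length distribution
# in 500-token buckets so a suitable cutoff_len can be chosen.
from length_cdf import length_cdf  # assumes this file is importable from the working directory

length_cdf(
    model_name_or_path="meta-llama/Llama-2-7b-hf",  # example model
    dataset="alpaca_en",
    template="default",
    interval=500,
)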

115
tests/llama_pro.py Normal file
View File

@@ -0,0 +1,115 @@
# coding=utf-8
# Performs block expansion for LLaMA, Mistral or Qwen1.5 models.
# Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
# Inspired by: https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py

import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Optional

import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)

if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel


def change_name(name: str, old_index: int, new_index: int) -> str:
    return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))


def block_expansion(
    model_name_or_path: str,
    output_dir: str,
    num_expand: int,
    shard_size: Optional[str] = "2GB",
    save_safetensors: Optional[bool] = False,
):
    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
    num_layers = getattr(config, "num_hidden_layers")
    setattr(config, "num_hidden_layers", num_layers + num_expand)
    config.save_pretrained(output_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.save_pretrained(output_dir)

    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)  # load the original one
    if save_safetensors:
        setattr(config, "tie_word_embeddings", False)  # safetensors does not allow shared weights

    model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        config=config,
        torch_dtype="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    state_dict = model.state_dict()

    if num_layers % num_expand != 0:
        raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))

    split = num_layers // num_expand
    layer_cnt = 0
    output_state_dict = OrderedDict()
    for i in range(num_layers):
        for key, value in state_dict.items():
            if ".{:d}.".format(i) in key:
                output_state_dict[change_name(key, i, layer_cnt)] = value

        print("Add layer {} copied from layer {}".format(layer_cnt, i))
        layer_cnt += 1

        if (i + 1) % split == 0:
            for key, value in state_dict.items():
                if ".{:d}.".format(i) in key:
                    if "down_proj" in key or "o_proj" in key:
                        output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
                    else:
                        output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)

            print("Add layer {} expanded from layer {}".format(layer_cnt, i))
            layer_cnt += 1

    for key, value in state_dict.items():
        if key not in output_state_dict:
            output_state_dict[key] = value

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(output_state_dict, max_shard_size=shard_size, weights_name=weights_name)
    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print("Model weights saved in {}".format(output_dir))

    print("Fine-tune this model with:")
    print(" --model_name_or_path {} \\".format(output_dir))
    print(" --finetuning_type freeze \\")
    print(" --name_module_trainable all \\")
    print(" --num_layer_trainable {} \\".format(num_expand))
    print(" --use_llama_pro")


if __name__ == "__main__":
    fire.Fire(block_expansion)
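To make the layer indexing concrete: the loop above inserts one copy after every split = num_layers // num_expand original layers, zeroing the copy's down_proj/o_proj weights so that, with the residual connections, each expanded block initially behaves as an identity mapping. A standalone sketch with small illustrative values (not part of the repo):

# Illustrative values only: 8 original layers expanded by 2 extra blocks.
num_layers, num_expand = 8, 2
split = num_layers // num_expand  # one new block after every 4 original layers

layer_cnt = 0
for i in range(num_layers):
    print("new layer {:d} <- copy of original layer {:d}".format(layer_cnt, i))
    layer_cnt += 1
    if (i + 1) % split == 0:
        print("new layer {:d} <- expansion of original layer {:d} (down_proj/o_proj zeroed)".format(layer_cnt, i))
        layer_cnt += 1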

View File

@@ -1,6 +1,6 @@
 # coding=utf-8
 # Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
-# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output --shard_size 10GB
+# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
 # Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
 # Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
@@ -76,7 +76,9 @@ def save_config(input_dir: str, output_dir: str):
     print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))

-def llamafy_baichuan2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
+def llamafy_baichuan2(
+    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
+):
     try:
         os.makedirs(output_dir, exist_ok=False)
     except Exception as e:

View File

@@ -1,6 +1,6 @@
 # coding=utf-8
 # Converts the InternLM2 model in the same format as LLaMA2.
-# Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB
+# Usage: python llamafy_internlm2.py --input_dir input --output_dir output
 # Warning: We have found that the converted model cannot infer correctly. It will be fixed later.

 import json
@@ -98,7 +98,9 @@ def save_config(input_dir: str, output_dir: str):
     print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))

-def llamafy_internlm2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
+def llamafy_internlm2(
+    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
+):
     try:
         os.makedirs(output_dir, exist_ok=False)
     except Exception as e:

View File

@@ -1,6 +1,6 @@
 # coding=utf-8
 # Converts the Qwen models in the same format as LLaMA2.
-# Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB
+# Usage: python llamafy_qwen.py --input_dir input --output_dir output
 # Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied

 import json
@@ -128,7 +128,9 @@ def save_config(input_dir: str, output_dir: str, torch_dtype: str):
     print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))

-def llamafy_qwen(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
+def llamafy_qwen(
+    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
+):
     try:
         os.makedirs(output_dir, exist_ok=False)
     except Exception as e:

View File

@@ -26,7 +26,7 @@ class Shell(nn.Module):
 def unwrap_model(model: nn.Module, pattern=".base_layer") -> None:
-    for name in set([k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k]):  # noqa: C403
+    for name in {k.split(pattern)[0] for k, _ in model.named_modules() if pattern in k}:
         parent_name = ".".join(name.split(".")[:-1])
         child_name = name.split(".")[-1]
         parent_module = model.get_submodule(parent_name)

57
tests/test_toolcall.py Normal file
View File

@@ -0,0 +1,57 @@
import json
from typing import Sequence

from openai import OpenAI
from transformers.utils.versions import require_version


require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")


def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
    grade_to_score = {"A": 4, "B": 3, "C": 2}
    total_score, total_hour = 0, 0
    for grade, hour in zip(grades, hours):
        total_score += grade_to_score[grade] * hour
        total_hour += hour
    return total_score / total_hour


tool_map = {"calculate_gpa": calculate_gpa}


if __name__ == "__main__":
    client = OpenAI(
        api_key="0",
        base_url="http://localhost:8000/v1",
    )
    tools = [
        {
            "type": "function",
            "function": {
                "name": "calculate_gpa",
                "description": "Calculate the Grade Point Average (GPA) based on grades and credit hours",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "grades": {"type": "array", "items": {"type": "string"}, "description": "The grades"},
                        "hours": {"type": "array", "items": {"type": "integer"}, "description": "The credit hours"},
                    },
                    "required": ["grades", "hours"],
                },
            },
        }
    ]
    messages = []
    messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
    tool_call = result.choices[0].message.tool_calls[0].function
    name, arguments = tool_call.name, json.loads(tool_call.arguments)
    messages.append(
        {"role": "function", "content": json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)}
    )
    tool_result = tool_map[name](**arguments)
    messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)})
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
    print(result.choices[0].message.content)
    # Based on your grades and credit hours, your calculated Grade Point Average (GPA) is 3.4166666666666665.