Compare commits
149 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ec334f5891 | ||
|
|
885efe772e | ||
|
|
64fc9ba678 | ||
|
|
989eccd286 | ||
|
|
f0766a2ab0 | ||
|
|
178b85ff9a | ||
|
|
68dd1ef121 | ||
|
|
b222cffe98 | ||
|
|
b4f1ab93d1 | ||
|
|
f2e139f5cd | ||
|
|
a9cbca1604 | ||
|
|
3a30ce6c16 | ||
|
|
48ec5355f9 | ||
|
|
11859bc322 | ||
|
|
28c67a5be8 | ||
|
|
44fe93e9b0 | ||
|
|
09a1681b63 | ||
|
|
f5ba2190fb | ||
|
|
14a38b5069 | ||
|
|
f23e5b602a | ||
|
|
857696ed9c | ||
|
|
2084133058 | ||
|
|
f7f0c3070e | ||
|
|
46235aa514 | ||
|
|
2eb65d21ac | ||
|
|
37a0d62a82 | ||
|
|
21ac46e439 | ||
|
|
ba3e8ba20c | ||
|
|
2c48e798ca | ||
|
|
4e40f5b62b | ||
|
|
2a8892b785 | ||
|
|
ee3b33ff03 | ||
|
|
b2c3001f8e | ||
|
|
6cfe1e1ac2 | ||
|
|
52326870e4 | ||
|
|
217fde0918 | ||
|
|
065021d82a | ||
|
|
4bb643e685 | ||
|
|
b77c745b1a | ||
|
|
7d13501b94 | ||
|
|
ac74639b32 | ||
|
|
12fa56ae68 | ||
|
|
f11b863f4b | ||
|
|
f3e4b72957 | ||
|
|
8d52fb46ca | ||
|
|
dab8f45033 | ||
|
|
bff8b02543 | ||
|
|
2406200914 | ||
|
|
db06fcfc84 | ||
|
|
93b9f74e9f | ||
|
|
33ec844f76 | ||
|
|
0f727b393e | ||
|
|
7da2aad6ee | ||
|
|
6f09f50d02 | ||
|
|
5919832059 | ||
|
|
f7635c1afc | ||
|
|
c762168ed0 | ||
|
|
67a46e553f | ||
|
|
e406f37b54 | ||
|
|
62fe877124 | ||
|
|
a0e682ba79 | ||
|
|
49e8a87383 | ||
|
|
b2764b49ca | ||
|
|
06b810de8f | ||
|
|
6da51565f5 | ||
|
|
1f69965239 | ||
|
|
af2d61178d | ||
|
|
6a955ccf4f | ||
|
|
c0658711ca | ||
|
|
d602f06882 | ||
|
|
1cb9a38ac2 | ||
|
|
47a1f73d0f | ||
|
|
142dd63b47 | ||
|
|
b1bd8370c2 | ||
|
|
215660c8da | ||
|
|
0cafe67efe | ||
|
|
ea83b3222b | ||
|
|
725087a04f | ||
|
|
d627ab4855 | ||
|
|
7d867e8df4 | ||
|
|
3d34d44497 | ||
|
|
a6f800b741 | ||
|
|
a003d1fa1e | ||
|
|
c2e84d4558 | ||
|
|
68330eab2a | ||
|
|
7070f3969d | ||
|
|
e4727ab155 | ||
|
|
280e7d97ad | ||
|
|
31e3805fb8 | ||
|
|
ef248dbe15 | ||
|
|
6a61b4b638 | ||
|
|
4b1473502f | ||
|
|
bf211d818d | ||
|
|
27dd87c890 | ||
|
|
8659084ab0 | ||
|
|
e1c9dcea93 | ||
|
|
171339ab17 | ||
|
|
8542ba5c69 | ||
|
|
97b74d328b | ||
|
|
3198a7e5f4 | ||
|
|
a2d08ce961 | ||
|
|
bd8ea09479 | ||
|
|
6d0d46c7fb | ||
|
|
820540780a | ||
|
|
f74d600497 | ||
|
|
94fec9f50e | ||
|
|
e387a50475 | ||
|
|
5c4248a29c | ||
|
|
f22886e2b6 | ||
|
|
33af3cbf37 | ||
|
|
728dfb1be7 | ||
|
|
e49f7f1afe | ||
|
|
21a454fa6c | ||
|
|
22c6c27f78 | ||
|
|
aecbb43096 | ||
|
|
fa53fd2db2 | ||
|
|
1c150995ae | ||
|
|
6c5d8f089e | ||
|
|
dd623325e8 | ||
|
|
e8a375c8f2 | ||
|
|
386d85ae72 | ||
|
|
ebb3901b05 | ||
|
|
20130b486c | ||
|
|
73c48d0463 | ||
|
|
f7cecd20e3 | ||
|
|
2bc64a7636 | ||
|
|
9564ddbb48 | ||
|
|
28062c71b5 | ||
|
|
35d1921081 | ||
|
|
4fbdf18c70 | ||
|
|
5e07ab01f0 | ||
|
|
fac465a21e | ||
|
|
e145a2ce0c | ||
|
|
dc68c313ee | ||
|
|
95c0d9ab24 | ||
|
|
46a718f339 | ||
|
|
496ba46960 | ||
|
|
43ae0aca1d | ||
|
|
b8574c1b82 | ||
|
|
32f8b1082b | ||
|
|
6443fef31a | ||
|
|
14c3795a7d | ||
|
|
3d9e2de573 | ||
|
|
0ca36a0f8d | ||
|
|
3e5555502a | ||
|
|
fbf5b5e0a9 | ||
|
|
3305e66f8c | ||
|
|
e19a44c12b | ||
|
|
8b0e6b9d1b |
240
README.md
240
README.md
@@ -1,59 +1,77 @@
|
||||
# LLaMA Efficient Tuning
|
||||
# LLaMA Factory: Training and Evaluating Large Language Models with Minimal Effort
|
||||
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/stargazers)
|
||||
[](LICENSE)
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/commits/main)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
||||
[](LICENSE)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/pulls)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/e73gccsSd)
|
||||
|
||||
👋 Join our [WeChat](assets/wechat.jpg).
|
||||
|
||||
\[ English | [中文](README_zh.md) \]
|
||||
|
||||
## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
|
||||
|
||||
Launch **LLaMA Board** via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet)
|
||||
|
||||
Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
|
||||
|
||||
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
|
||||
|
||||
## Changelog
|
||||
|
||||
[23/09/10] Now we support using **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs (experimental feature).
|
||||
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.
|
||||
|
||||
[23/08/18] Now we support **resuming training**, upgrade `transformers` to `4.31.0` to enjoy this feature.
|
||||
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.
|
||||
|
||||
[23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
|
||||
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
|
||||
|
||||
[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models.
|
||||
[23/09/10] We supported using **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
|
||||
|
||||
[23/07/31] Now we support **dataset streaming**. Try `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode.
|
||||
[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
|
||||
|
||||
[23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
|
||||
[23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models.
|
||||
|
||||
[23/07/18] Now we develop an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
|
||||
[23/07/31] We supported **dataset streaming**. Try `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode.
|
||||
|
||||
[23/07/09] Now we release **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
|
||||
[23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
|
||||
|
||||
[23/06/29] We provide a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
|
||||
[23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
|
||||
|
||||
[23/06/22] Now we align the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
|
||||
[23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
|
||||
|
||||
[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
|
||||
[23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
|
||||
|
||||
[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
|
||||
|
||||
[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
|
||||
|
||||
## Supported Models
|
||||
|
||||
| Model | Model size | Default module | Template |
|
||||
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
|
||||
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
|
||||
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
|
||||
| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
|
||||
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
|
||||
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
|
||||
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
|
||||
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
|
||||
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
|
||||
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
|
||||
> [!NOTE]
|
||||
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
|
||||
>
|
||||
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
|
||||
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
|
||||
|
||||
Please refer to [template.py](src/llmtuner/extras/template.py) for a full list of models we supported.
|
||||
|
||||
## Supported Training Approaches
|
||||
|
||||
@@ -70,38 +88,61 @@
|
||||
|
||||
## Provided Datasets
|
||||
|
||||
- For pre-training:
|
||||
- [Wiki Demo (en)](data/wiki_demo.txt)
|
||||
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
||||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- For supervised fine-tuning:
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Self-cognition (zh)](data/self_cognition.json)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
||||
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- For reward modeling or DPO training:
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
<details><summary>Pre-training datasets</summary>
|
||||
|
||||
- [Wiki Demo (en)](data/wiki_demo.txt)
|
||||
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
||||
- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
|
||||
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
||||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
|
||||
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
|
||||
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>Supervised fine-tuning datasets</summary>
|
||||
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Self-cognition (zh)](data/self_cognition.json)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
||||
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
|
||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>Preference datasets</summary>
|
||||
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
|
||||
</details>
|
||||
|
||||
Please refer to [data/README.md](data/README.md) for details.
|
||||
|
||||
@@ -117,9 +158,9 @@ huggingface-cli login
|
||||
- Python 3.8+ and PyTorch 1.13.1+
|
||||
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
|
||||
- sentencepiece, protobuf and tiktoken
|
||||
- jieba, rouge-chinese and nltk (used at evaluation)
|
||||
- gradio and matplotlib (used in web_demo.py)
|
||||
- uvicorn, fastapi and sse-starlette (used in api_demo.py)
|
||||
- jieba, rouge-chinese and nltk (used at evaluation and predict)
|
||||
- gradio and matplotlib (used in web UI)
|
||||
- uvicorn, fastapi and sse-starlette (used in API)
|
||||
|
||||
And **powerful GPUs**!
|
||||
|
||||
@@ -127,7 +168,7 @@ And **powerful GPUs**!
|
||||
|
||||
### Data Preparation (optional)
|
||||
|
||||
Please refer to `data/example_dataset` for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
|
||||
Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
|
||||
|
||||
> [!NOTE]
|
||||
> Please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
|
||||
@@ -135,10 +176,10 @@ Please refer to `data/example_dataset` for checking the details about the format
|
||||
### Dependence Installation (optional)
|
||||
|
||||
```bash
|
||||
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
|
||||
conda create -n llama_etuning python=3.10
|
||||
conda activate llama_etuning
|
||||
cd LLaMA-Efficient-Tuning
|
||||
git clone https://github.com/hiyouga/LLaMA-Factory.git
|
||||
conda create -n llama_factory python=3.10
|
||||
conda activate llama_factory
|
||||
cd LLaMA-Factory
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
@@ -148,17 +189,6 @@ If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you wi
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
### All-in-one Web UI
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
|
||||
```
|
||||
|
||||
We strongly recommend using the all-in-one Web UI for newcomers since it can also generate training scripts **automatically**.
|
||||
|
||||
> [!WARNING]
|
||||
> Currently the web UI only supports training on **a single GPU**.
|
||||
|
||||
### Train on a single GPU
|
||||
|
||||
> [!IMPORTANT]
|
||||
@@ -365,7 +395,7 @@ python src/export_model.py \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_export
|
||||
--export_dir path_to_export
|
||||
```
|
||||
|
||||
### API Demo
|
||||
@@ -401,26 +431,21 @@ python src/web_demo.py \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### Evaluation (BLEU and ROUGE_CHINESE)
|
||||
### Evaluation
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_eval \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_eval_result \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--max_samples 100 \
|
||||
--predict_with_generate
|
||||
--template vanilla \
|
||||
--task mmlu \
|
||||
--split test \
|
||||
--lang en \
|
||||
--n_shot 5 \
|
||||
--batch_size 4
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
|
||||
|
||||
### Predict
|
||||
|
||||
```bash
|
||||
@@ -438,40 +463,39 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--predict_with_generate
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit predict.
|
||||
|
||||
## Projects using LLaMA Factory
|
||||
|
||||
- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
||||
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
||||
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
||||
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
||||
|
||||
## License
|
||||
|
||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||
|
||||
Please follow the model licenses to use the corresponding model weights:
|
||||
|
||||
- [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
|
||||
- [LLaMA-2](https://ai.meta.com/llama/license/)
|
||||
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
|
||||
- [Falcon](LICENSE)
|
||||
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
|
||||
- [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/resolve/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
|
||||
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
|
||||
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
|
||||
- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
|
||||
- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
|
||||
|
||||
## Citation
|
||||
|
||||
If this work is helpful, please kindly cite as:
|
||||
|
||||
```bibtex
|
||||
@Misc{llama-efficient-tuning,
|
||||
title = {LLaMA Efficient Tuning},
|
||||
@Misc{llama-factory,
|
||||
title = {LLaMA Factory},
|
||||
author = {hiyouga},
|
||||
howpublished = {\url{https://github.com/hiyouga/LLaMA-Efficient-Tuning}},
|
||||
howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora) and [OpenChatKit](https://github.com/togethercomputer/OpenChatKit). Thanks for their wonderful works.
|
||||
This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works.
|
||||
|
||||
## Star History
|
||||
|
||||

|
||||

|
||||
|
||||
234
README_zh.md
234
README_zh.md
@@ -1,30 +1,44 @@
|
||||
# LLaMA Efficient Tuning
|
||||
# LLaMA Factory: 轻松的大模型训练与评估
|
||||
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/stargazers)
|
||||
[](LICENSE)
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/commits/main)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
||||
[](LICENSE)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/pulls)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/e73gccsSd)
|
||||
|
||||
👋 加入我们的[微信群](assets/wechat.jpg)。
|
||||
|
||||
\[ [English](README.md) | 中文 \]
|
||||
|
||||
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
|
||||
|
||||
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 **LLaMA Board**。(该界面目前仅支持单卡训练)
|
||||
|
||||
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
|
||||
|
||||
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
|
||||
|
||||
## 更新日志
|
||||
|
||||
[23/09/10] 现在我们支持了 LLaMA 模型的 **[FlashAttention](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2(实验性功能)。
|
||||
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neft_alpha` 参数启用 NEFTune,例如 `--neft_alpha 5`。
|
||||
|
||||
[23/08/18] 现在我们支持了**训练状态恢复**,请将 `transformers` 升级至 `4.31.0` 以启用此功能。
|
||||
[23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。
|
||||
|
||||
[23/08/12] 现在我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
|
||||
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
|
||||
|
||||
[23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-训练)。
|
||||
[23/09/10] 我们针对 LLaMA 模型支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。
|
||||
|
||||
[23/07/31] 现在我们支持了**数据流式加载**。请尝试使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。
|
||||
[23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
|
||||
|
||||
[23/08/11] 我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。使用方法请参阅[此示例](#dpo-训练)。
|
||||
|
||||
[23/07/31] 我们支持了**数据流式加载**。请尝试使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。
|
||||
|
||||
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。
|
||||
|
||||
[23/07/18] 我们开发了支持训练和测试的**浏览器一体化界面**。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
|
||||
[23/07/18] 我们开发了支持训练和测试的**浏览器一体化界面**。请使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
|
||||
|
||||
[23/07/09] 我们开源了 **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
|
||||
|
||||
@@ -32,28 +46,32 @@
|
||||
|
||||
[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入**任意基于 ChatGPT 的应用**中。
|
||||
|
||||
[23/06/03] 现在我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请尝试使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
|
||||
[23/06/03] 我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
|
||||
|
||||
## 模型
|
||||
|
||||
| 模型名 | 模型大小 | 默认模块 | Template |
|
||||
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
|
||||
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
|
||||
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
|
||||
| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
|
||||
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
|
||||
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
|
||||
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | falcon |
|
||||
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
|
||||
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
|
||||
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
|
||||
> [!NOTE]
|
||||
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
|
||||
>
|
||||
> 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用对应的模板。
|
||||
> 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用**对应的模板**。
|
||||
|
||||
项目所支持模型的完整列表请参阅 [template.py](src/llmtuner/extras/template.py)。
|
||||
|
||||
## 训练方法
|
||||
|
||||
@@ -70,40 +88,63 @@
|
||||
|
||||
## 数据集
|
||||
|
||||
- 用于预训练:
|
||||
- [Wiki Demo (en)](data/wiki_demo.txt)
|
||||
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
||||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- 用于指令监督微调:
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Self-cognition (zh)](data/self_cognition.json)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
||||
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- 用于训练奖励模型或 DPO 训练:
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
<details><summary>预训练数据集</summary>
|
||||
|
||||
使用方法请参考 [data/README.md](data/README_zh.md) 文件。
|
||||
- [Wiki Demo (en)](data/wiki_demo.txt)
|
||||
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
||||
- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
|
||||
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
||||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
|
||||
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
|
||||
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>指令微调数据集</summary>
|
||||
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Self-cognition (zh)](data/self_cognition.json)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
||||
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
|
||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>偏好数据集</summary>
|
||||
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
|
||||
</details>
|
||||
|
||||
使用方法请参考 [data/README_zh.md](data/README_zh.md) 文件。
|
||||
|
||||
部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。
|
||||
|
||||
@@ -117,7 +158,7 @@ huggingface-cli login
|
||||
- Python 3.8+ 和 PyTorch 1.13.1+
|
||||
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
|
||||
- sentencepiece, protobuf 和 tiktoken
|
||||
- jieba, rouge-chinese 和 nltk (用于评估)
|
||||
- jieba, rouge-chinese 和 nltk (用于评估及预测)
|
||||
- gradio 和 matplotlib (用于网页端交互)
|
||||
- uvicorn, fastapi 和 sse-starlette (用于 API)
|
||||
|
||||
@@ -127,18 +168,18 @@ huggingface-cli login
|
||||
|
||||
### 数据准备(可跳过)
|
||||
|
||||
关于数据集文件的格式,请参考 `data/example_dataset` 文件夹的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
|
||||
关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
|
||||
|
||||
> [!NOTE]
|
||||
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README.md`。
|
||||
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README_zh.md`。
|
||||
|
||||
### 环境搭建(可跳过)
|
||||
|
||||
```bash
|
||||
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
|
||||
conda create -n llama_etuning python=3.10
|
||||
conda activate llama_etuning
|
||||
cd LLaMA-Efficient-Tuning
|
||||
git clone https://github.com/hiyouga/LLaMA-Factory.git
|
||||
conda create -n llama_factory python=3.10
|
||||
conda activate llama_factory
|
||||
cd LLaMA-Factory
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
@@ -148,17 +189,6 @@ pip install -r requirements.txt
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
### 浏览器一体化界面
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
|
||||
```
|
||||
|
||||
我们极力推荐新手使用浏览器一体化界面,因为它还可以**自动**生成运行所需的命令行脚本。
|
||||
|
||||
> [!WARNING]
|
||||
> 目前网页 UI 仅支持**单卡训练**。
|
||||
|
||||
### 单 GPU 训练
|
||||
|
||||
> [!IMPORTANT]
|
||||
@@ -356,7 +386,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
|
||||
|
||||
</details>
|
||||
|
||||
### 导出微调后的模型
|
||||
### 导出微调后的完整模型
|
||||
|
||||
```bash
|
||||
python src/export_model.py \
|
||||
@@ -364,7 +394,7 @@ python src/export_model.py \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_export
|
||||
--export_dir path_to_export
|
||||
```
|
||||
|
||||
### API 服务
|
||||
@@ -400,26 +430,21 @@ python src/web_demo.py \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### 指标评估(BLEU 分数和汉语 ROUGE 分数)
|
||||
### 模型评估
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_eval \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_eval_result \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--max_samples 100 \
|
||||
--predict_with_generate
|
||||
--template vanilla \
|
||||
--task ceval \
|
||||
--split validation \
|
||||
--lang zh \
|
||||
--n_shot 5 \
|
||||
--batch_size 4
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> 我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
|
||||
|
||||
### 模型预测
|
||||
|
||||
```bash
|
||||
@@ -437,40 +462,39 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--predict_with_generate
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
|
||||
|
||||
## 使用了 LLaMA Factory 的项目
|
||||
|
||||
- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
||||
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
||||
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
|
||||
|
||||
## 协议
|
||||
|
||||
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
||||
|
||||
使用模型权重时,请遵循对应的模型协议:
|
||||
|
||||
- [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
|
||||
- [LLaMA-2](https://ai.meta.com/llama/license/)
|
||||
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
|
||||
- [Falcon](LICENSE)
|
||||
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
|
||||
- [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/resolve/main/Baichuan%202%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
|
||||
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
|
||||
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
|
||||
- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
|
||||
- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
|
||||
|
||||
## 引用
|
||||
|
||||
如果您觉得此项目有帮助,请考虑以下列格式引用
|
||||
|
||||
```bibtex
|
||||
@Misc{llama-efficient-tuning,
|
||||
title = {LLaMA Efficient Tuning},
|
||||
@Misc{llama-factory,
|
||||
title = {LLaMA Factory},
|
||||
author = {hiyouga},
|
||||
howpublished = {\url{https://github.com/hiyouga/LLaMA-Efficient-Tuning}},
|
||||
howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
|
||||
## 致谢
|
||||
|
||||
本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora) 和 [OpenChatKit](https://github.com/togethercomputer/OpenChatKit),感谢以上诸位作者的付出。
|
||||
本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora) 和 [FastChat](https://github.com/lm-sys/FastChat),感谢以上诸位作者的付出。
|
||||
|
||||
## Star History
|
||||
|
||||

|
||||

|
||||
|
||||
101
data/README.md
101
data/README.md
@@ -2,31 +2,106 @@ If you are using a custom dataset, please provide your dataset definition in the
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
|
||||
"hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore below 3 arguments)",
|
||||
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
|
||||
"file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
|
||||
"file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
|
||||
"ranking": "whether the examples contains ranked responses or not. (default: false)",
|
||||
"file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
|
||||
"subset": "the name of the subset. (optional, default: None)",
|
||||
"ranking": "whether the dataset is a preference dataset or not. (default: false)",
|
||||
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
|
||||
"columns": {
|
||||
"prompt": "the name of the column in the datasets containing the prompts. (default: instruction)",
|
||||
"query": "the name of the column in the datasets containing the queries. (default: input)",
|
||||
"response": "the name of the column in the datasets containing the responses. (default: output)",
|
||||
"history": "the name of the column in the datasets containing the history of chat. (default: None)"
|
||||
"prompt": "the column name in the dataset containing the prompts. (default: instruction, for alpaca)",
|
||||
"query": "the column name in the dataset containing the queries. (default: input, for alpaca)",
|
||||
"response": "the column name in the dataset containing the responses. (default: output, for alpaca)",
|
||||
"history": "the column name in the dataset containing the histories. (default: None, for alpaca)",
|
||||
"messages": "the column name in the dataset containing the messages. (default: conversations, for sharegpt)",
|
||||
"role": "the key in the message represents the identity. (default: from, for sharegpt)",
|
||||
"content": "the key in the message represents the content. (default: value, for sharegpt)"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
where the `prompt` and `response` columns should contain non-empty values. The `query` column will be concatenated with the `prompt` column and used as input for the model. The `history` column should contain a list where each element is a string tuple representing a query-response pair.
|
||||
Given above, you can use the custom dataset via specifying `--dataset dataset_name`.
|
||||
|
||||
For datasets used in reward modeling or DPO training, the `response` column should be a string list, with the preferred answers appearing first, for example:
|
||||
Currently we support dataset in **alpaca** or **sharegpt** format, the dataset in alpaca format should follow the below format:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "user instruction (required)",
|
||||
"input": "user input (optional)",
|
||||
"output": "model response (required)",
|
||||
"history": [
|
||||
["user instruction in the first round (optional)", "model response in the first round (optional)"],
|
||||
["user instruction in the second round (optional)", "model response in the second round (optional)"]
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"history": "history"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
where the `prompt` and `response` columns should contain non-empty values, represent instruction and response respectively. The `query` column will be concatenated with the `prompt` column and used as input for the model.
|
||||
|
||||
The `history` column is a list consisting string tuples representing query-response pairs in history. Note that the responses **in each round will be used for training**.
|
||||
|
||||
For the pre-training datasets, only the `prompt` column will be used for training.
|
||||
|
||||
For the preference datasets, the `response` column should be a string list whose length is 2, with the preferred answers appearing first, for example:
|
||||
|
||||
```json
|
||||
{
|
||||
"instruction": "Question",
|
||||
"input": "",
|
||||
"instruction": "user instruction",
|
||||
"input": "user input",
|
||||
"output": [
|
||||
"Chosen answer",
|
||||
"Rejected answer"
|
||||
"chosen answer",
|
||||
"rejected answer"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The dataset in sharegpt format should follow the below format:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "user instruction"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "model response"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"role": "from",
|
||||
"content": "value"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
where the `messages` column should be a list whose length is even, and follow the `u/a/u/a/u/a` order.
|
||||
|
||||
Pre-training datasets and preference datasets are incompatible with the sharegpt format yet.
|
||||
|
||||
@@ -1,32 +1,107 @@
|
||||
如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中以如下格式提供您的数据集定义。
|
||||
如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中按照以下格式提供数据集定义。
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"hf_hub_url": "HuggingFace上的项目地址(若指定,则忽略下列三个参数)",
|
||||
"hf_hub_url": "Hugging Face 上的项目地址(若指定,则忽略下列三个参数)",
|
||||
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
|
||||
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
|
||||
"file_sha1": "数据集文件的SHA-1哈希值(可选)",
|
||||
"ranking": "数据集是否包含排序后的回答(默认:false)",
|
||||
"file_sha1": "数据集文件的SHA-1哈希值(可选,留空不影响训练)",
|
||||
"subset": "数据集子集的名称(可选,默认:None)",
|
||||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
||||
"columns": {
|
||||
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
||||
"query": "数据集代表请求的表头名称(默认:input)",
|
||||
"response": "数据集代表回答的表头名称(默认:output)",
|
||||
"history": "数据集代表历史对话的表头名称(默认:None)"
|
||||
"prompt": "数据集代表提示词的表头名称(默认:instruction,用于 alpaca 格式)",
|
||||
"query": "数据集代表请求的表头名称(默认:input,用于 alpaca 格式)",
|
||||
"response": "数据集代表回答的表头名称(默认:output,用于 alpaca 格式)",
|
||||
"history": "数据集代表历史对话的表头名称(默认:None,用于 alpaca 格式)",
|
||||
"messages": "数据集代表消息列表的表头名称(默认:conversations,用于 sharegpt 格式)",
|
||||
"role": "消息中代表发送者身份的键名(默认:from,用于 sharegpt 格式)",
|
||||
"content": "消息中代表文本内容的键名(默认:value,用于 sharegpt 格式)"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
其中 `prompt` 和 `response` 列应当是非空的字符串。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。`history` 列应当是一个列表,其中每个元素是一个字符串二元组,分别代表用户请求和模型答复。
|
||||
添加后可通过指定 `--dataset 数据集名称` 参数使用自定义数据集。
|
||||
|
||||
对于训练奖励模型或 DPO 训练的数据集,`response` 列应当是一个字符串列表,排在前面的代表更优的答案,例如:
|
||||
该项目目前支持两种格式的数据集:**alpaca** 和 **sharegpt**,其中 alpaca 格式的数据集按照以下方式组织:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "用户指令(必填)",
|
||||
"input": "用户输入(选填)",
|
||||
"output": "模型回答(必填)",
|
||||
"history": [
|
||||
["第一轮指令(选填)", "第一轮回答(选填)"],
|
||||
["第二轮指令(选填)", "第二轮回答(选填)"]
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"history": "history"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
其中 `prompt` 和 `response` 列应当是非空的字符串,分别代表用户指令和模型回答。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。
|
||||
|
||||
`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意每轮的模型回答**均会被用于训练**。
|
||||
|
||||
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
|
||||
|
||||
对于偏好数据集,`response` 列应当是一个长度为 2 的字符串列表,排在前面的代表更优的回答,例如:
|
||||
|
||||
```json
|
||||
{
|
||||
"instruction": "Question",
|
||||
"input": "",
|
||||
"instruction": "用户指令",
|
||||
"input": "用户输入",
|
||||
"output": [
|
||||
"Chosen answer",
|
||||
"Rejected answer"
|
||||
"优质回答",
|
||||
"劣质回答"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
而 sharegpt 格式的数据集按照以下方式组织:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "用户指令"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "模型回答"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"role": "from",
|
||||
"content": "value"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
其中 `messages` 列必须为偶数长度的列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
|
||||
|
||||
预训练数据集和偏好数据集尚不支持 sharegpt 格式。
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import datasets
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
_DESCRIPTION = "BELLE multiturn chat dataset."
|
||||
@@ -23,7 +22,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
|
||||
VERSION = datasets.Version("0.0.0")
|
||||
|
||||
def _info(self) -> datasets.DatasetInfo:
|
||||
def _info(self):
|
||||
features = datasets.Features({
|
||||
"instruction": datasets.Value("string"),
|
||||
"output": datasets.Value("string"),
|
||||
@@ -37,7 +36,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
citation=_CITATION
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager):
|
||||
file_path = dl_manager.download(_URL)
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
@@ -48,7 +47,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
)
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat with history
|
||||
def _generate_examples(self, filepath: str):
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
for key, row in enumerate(f):
|
||||
data = json.loads(row)
|
||||
|
||||
@@ -3,7 +3,7 @@ import datasets
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
_DESCRIPTION = "An example of dataset for LLaMA."
|
||||
_DESCRIPTION = "An example of dataset."
|
||||
_CITATION = ""
|
||||
_HOMEPAGE = ""
|
||||
_LICENSE = ""
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import datasets
|
||||
from typing import Any, Dict, List
|
||||
from typing import List
|
||||
|
||||
|
||||
_DESCRIPTION = "Human preference data about helpfulness and harmlessness for ChatGLM."
|
||||
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
|
||||
_CITATION = ""
|
||||
_HOMEPAGE = "https://huggingface.co/datasets/Anthropic/hh-rlhf"
|
||||
_LICENSE = "mit"
|
||||
@@ -42,7 +42,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
||||
citation=_CITATION
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager):
|
||||
file_path = dl_manager.download_and_extract(_URLS)
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
@@ -59,7 +59,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
||||
)
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
|
||||
def _generate_examples(self, filepaths: List[str]):
|
||||
key = 0
|
||||
for filepath in filepaths:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import datasets
|
||||
from typing import Any, Dict, List
|
||||
from typing import List
|
||||
|
||||
|
||||
_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
|
||||
@@ -21,15 +21,13 @@ _LICENSE = "cc-by-nc-4.0"
|
||||
_BASE_DATA_URL = "https://huggingface.co/datasets/stingning/ultrachat/resolve/main/train_{idx}.jsonl"
|
||||
|
||||
|
||||
class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
class UltraChat(datasets.GeneratorBasedBuilder):
|
||||
|
||||
VERSION = datasets.Version("0.0.0")
|
||||
|
||||
def _info(self) -> datasets.DatasetInfo:
|
||||
def _info(self):
|
||||
features = datasets.Features({
|
||||
"instruction": datasets.Value("string"),
|
||||
"output": datasets.Value("string"),
|
||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
|
||||
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
|
||||
})
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
@@ -39,8 +37,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
citation=_CITATION
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
|
||||
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(9)] # multiple shards
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager):
|
||||
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
@@ -50,7 +48,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
)
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
|
||||
def _generate_examples(self, filepaths: List[str]):
|
||||
for filepath in filepaths:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
for row in f:
|
||||
@@ -58,19 +56,16 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
data = json.loads(row)
|
||||
except:
|
||||
continue
|
||||
key = data["id"]
|
||||
content = data["data"]
|
||||
key: int = data["id"]
|
||||
content: List[str] = data["data"]
|
||||
if len(content) % 2 == 1:
|
||||
content.pop(-1)
|
||||
if len(content) < 2:
|
||||
continue
|
||||
|
||||
query = content[-2]
|
||||
response = content[-1]
|
||||
history = [[content[2*i], content[2*i+1]] for i in range(len(content) // 2 - 1)]
|
||||
|
||||
conversations = [{
|
||||
"from": "human" if i % 2 == 0 else "gpt",
|
||||
"value": content[i]
|
||||
} for i in range(len(content))]
|
||||
yield key, {
|
||||
"instruction": query,
|
||||
"output": response,
|
||||
"history": history
|
||||
"conversations": conversations
|
||||
}
|
||||
|
||||
166
evaluation/ceval/ceval.py
Normal file
166
evaluation/ceval/ceval.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@article{huang2023ceval,
|
||||
title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
|
||||
author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
|
||||
journal={arXiv preprint arXiv:2305.08322},
|
||||
year={2023}
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
C-Eval is a comprehensive Chinese evaluation suite for foundation models. It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels.
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "https://cevalbenchmark.com"
|
||||
|
||||
_LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License"
|
||||
|
||||
_URL = "ceval.zip"
|
||||
|
||||
task_list = [
|
||||
"computer_network",
|
||||
"operating_system",
|
||||
"computer_architecture",
|
||||
"college_programming",
|
||||
"college_physics",
|
||||
"college_chemistry",
|
||||
"advanced_mathematics",
|
||||
"probability_and_statistics",
|
||||
"discrete_mathematics",
|
||||
"electrical_engineer",
|
||||
"metrology_engineer",
|
||||
"high_school_mathematics",
|
||||
"high_school_physics",
|
||||
"high_school_chemistry",
|
||||
"high_school_biology",
|
||||
"middle_school_mathematics",
|
||||
"middle_school_biology",
|
||||
"middle_school_physics",
|
||||
"middle_school_chemistry",
|
||||
"veterinary_medicine",
|
||||
"college_economics",
|
||||
"business_administration",
|
||||
"marxism",
|
||||
"mao_zedong_thought",
|
||||
"education_science",
|
||||
"teacher_qualification",
|
||||
"high_school_politics",
|
||||
"high_school_geography",
|
||||
"middle_school_politics",
|
||||
"middle_school_geography",
|
||||
"modern_chinese_history",
|
||||
"ideological_and_moral_cultivation",
|
||||
"logic",
|
||||
"law",
|
||||
"chinese_language_and_literature",
|
||||
"art_studies",
|
||||
"professional_tour_guide",
|
||||
"legal_professional",
|
||||
"high_school_chinese",
|
||||
"high_school_history",
|
||||
"middle_school_history",
|
||||
"civil_servant",
|
||||
"sports_science",
|
||||
"plant_protection",
|
||||
"basic_medicine",
|
||||
"clinical_medicine",
|
||||
"urban_and_rural_planner",
|
||||
"accountant",
|
||||
"fire_engineer",
|
||||
"environmental_impact_assessment_engineer",
|
||||
"tax_accountant",
|
||||
"physician",
|
||||
]
|
||||
|
||||
|
||||
class CevalConfig(datasets.BuilderConfig):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(version=datasets.Version("1.0.0"), **kwargs)
|
||||
|
||||
|
||||
class Ceval(datasets.GeneratorBasedBuilder):
|
||||
BUILDER_CONFIGS = [
|
||||
CevalConfig(
|
||||
name=task_name,
|
||||
)
|
||||
for task_name in task_list
|
||||
]
|
||||
|
||||
def _info(self):
|
||||
features = datasets.Features(
|
||||
{
|
||||
"id": datasets.Value("int32"),
|
||||
"question": datasets.Value("string"),
|
||||
"A": datasets.Value("string"),
|
||||
"B": datasets.Value("string"),
|
||||
"C": datasets.Value("string"),
|
||||
"D": datasets.Value("string"),
|
||||
"answer": datasets.Value("string"),
|
||||
"explanation": datasets.Value("string"),
|
||||
}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
features=features,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
citation=_CITATION,
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager):
|
||||
data_dir = dl_manager.download_and_extract(_URL)
|
||||
task_name = self.config.name
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "test", f"{task_name}_test.csv"
|
||||
),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.VALIDATION,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "val", f"{task_name}_val.csv"
|
||||
),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "dev", f"{task_name}_dev.csv"
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepath):
|
||||
df = pd.read_csv(filepath, encoding="utf-8")
|
||||
for i, instance in enumerate(df.to_dict(orient="records")):
|
||||
if "answer" not in instance.keys():
|
||||
instance["answer"] = ""
|
||||
if "explanation" not in instance.keys():
|
||||
instance["explanation"] = ""
|
||||
yield i, instance
|
||||
167
evaluation/cmmlu/cmmlu.py
Normal file
167
evaluation/cmmlu/cmmlu.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@article{li2023cmmlu,
|
||||
title={CMMLU: Measuring massive multitask language understanding in Chinese},
|
||||
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
|
||||
journal={arXiv preprint arXiv:2306.09212},
|
||||
year={2023}
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge and reasoning abilities of LLMs within the Chinese language and cultural context.
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "https://github.com/haonan-li/CMMLU"
|
||||
|
||||
_LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License"
|
||||
|
||||
_URL = "cmmlu.zip"
|
||||
|
||||
task_list = [
|
||||
'agronomy',
|
||||
'anatomy',
|
||||
'ancient_chinese',
|
||||
'arts',
|
||||
'astronomy',
|
||||
'business_ethics',
|
||||
'chinese_civil_service_exam',
|
||||
'chinese_driving_rule',
|
||||
'chinese_food_culture',
|
||||
'chinese_foreign_policy',
|
||||
'chinese_history',
|
||||
'chinese_literature',
|
||||
'chinese_teacher_qualification',
|
||||
'clinical_knowledge',
|
||||
'college_actuarial_science',
|
||||
'college_education',
|
||||
'college_engineering_hydrology',
|
||||
'college_law',
|
||||
'college_mathematics',
|
||||
'college_medical_statistics',
|
||||
'college_medicine',
|
||||
'computer_science',
|
||||
'computer_security',
|
||||
'conceptual_physics',
|
||||
'construction_project_management',
|
||||
'economics',
|
||||
'education',
|
||||
'electrical_engineering',
|
||||
'elementary_chinese',
|
||||
'elementary_commonsense',
|
||||
'elementary_information_and_technology',
|
||||
'elementary_mathematics',
|
||||
'ethnology',
|
||||
'food_science',
|
||||
'genetics',
|
||||
'global_facts',
|
||||
'high_school_biology',
|
||||
'high_school_chemistry',
|
||||
'high_school_geography',
|
||||
'high_school_mathematics',
|
||||
'high_school_physics',
|
||||
'high_school_politics',
|
||||
'human_sexuality',
|
||||
'international_law',
|
||||
'journalism',
|
||||
'jurisprudence',
|
||||
'legal_and_moral_basis',
|
||||
'logical',
|
||||
'machine_learning',
|
||||
'management',
|
||||
'marketing',
|
||||
'marxist_theory',
|
||||
'modern_chinese',
|
||||
'nutrition',
|
||||
'philosophy',
|
||||
'professional_accounting',
|
||||
'professional_law',
|
||||
'professional_medicine',
|
||||
'professional_psychology',
|
||||
'public_relations',
|
||||
'security_study',
|
||||
'sociology',
|
||||
'sports_science',
|
||||
'traditional_chinese_medicine',
|
||||
'virology',
|
||||
'world_history',
|
||||
'world_religions',
|
||||
]
|
||||
|
||||
|
||||
class CMMLUConfig(datasets.BuilderConfig):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(version=datasets.Version("1.0.1"), **kwargs)
|
||||
|
||||
|
||||
class CMMLU(datasets.GeneratorBasedBuilder):
|
||||
BUILDER_CONFIGS = [
|
||||
CMMLUConfig(
|
||||
name=task_name,
|
||||
)
|
||||
for task_name in task_list
|
||||
]
|
||||
|
||||
def _info(self):
|
||||
features = datasets.Features(
|
||||
{
|
||||
"question": datasets.Value("string"),
|
||||
"A": datasets.Value("string"),
|
||||
"B": datasets.Value("string"),
|
||||
"C": datasets.Value("string"),
|
||||
"D": datasets.Value("string"),
|
||||
"answer": datasets.Value("string"),
|
||||
}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
features=features,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
citation=_CITATION,
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager):
|
||||
data_dir = dl_manager.download_and_extract(_URL)
|
||||
task_name = self.config.name
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(data_dir, f"test/{task_name}.csv"),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(data_dir, f"dev/{task_name}.csv"),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepath):
|
||||
df = pd.read_csv(filepath, header=0, index_col=0, encoding="utf-8")
|
||||
for i, instance in enumerate(df.to_dict(orient="records")):
|
||||
question = instance.pop("Question", "")
|
||||
answer = instance.pop("Answer", "")
|
||||
instance["question"] = question
|
||||
instance["answer"] = answer
|
||||
yield i, instance
|
||||
167
evaluation/mmlu/mmlu.py
Normal file
167
evaluation/mmlu/mmlu.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@article{hendryckstest2021,
|
||||
title={Measuring Massive Multitask Language Understanding},
|
||||
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
|
||||
journal={Proceedings of the International Conference on Learning Representations (ICLR)},
|
||||
year={2021}
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021).
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "https://github.com/hendrycks/test"
|
||||
|
||||
_LICENSE = "MIT"
|
||||
|
||||
_URL = "mmlu.zip"
|
||||
|
||||
task_list = [
|
||||
"high_school_european_history",
|
||||
"business_ethics",
|
||||
"clinical_knowledge",
|
||||
"medical_genetics",
|
||||
"high_school_us_history",
|
||||
"high_school_physics",
|
||||
"high_school_world_history",
|
||||
"virology",
|
||||
"high_school_microeconomics",
|
||||
"econometrics",
|
||||
"college_computer_science",
|
||||
"high_school_biology",
|
||||
"abstract_algebra",
|
||||
"professional_accounting",
|
||||
"philosophy",
|
||||
"professional_medicine",
|
||||
"nutrition",
|
||||
"global_facts",
|
||||
"machine_learning",
|
||||
"security_studies",
|
||||
"public_relations",
|
||||
"professional_psychology",
|
||||
"prehistory",
|
||||
"anatomy",
|
||||
"human_sexuality",
|
||||
"college_medicine",
|
||||
"high_school_government_and_politics",
|
||||
"college_chemistry",
|
||||
"logical_fallacies",
|
||||
"high_school_geography",
|
||||
"elementary_mathematics",
|
||||
"human_aging",
|
||||
"college_mathematics",
|
||||
"high_school_psychology",
|
||||
"formal_logic",
|
||||
"high_school_statistics",
|
||||
"international_law",
|
||||
"high_school_mathematics",
|
||||
"high_school_computer_science",
|
||||
"conceptual_physics",
|
||||
"miscellaneous",
|
||||
"high_school_chemistry",
|
||||
"marketing",
|
||||
"professional_law",
|
||||
"management",
|
||||
"college_physics",
|
||||
"jurisprudence",
|
||||
"world_religions",
|
||||
"sociology",
|
||||
"us_foreign_policy",
|
||||
"high_school_macroeconomics",
|
||||
"computer_security",
|
||||
"moral_scenarios",
|
||||
"moral_disputes",
|
||||
"electrical_engineering",
|
||||
"astronomy",
|
||||
"college_biology",
|
||||
]
|
||||
|
||||
|
||||
class MMLUConfig(datasets.BuilderConfig):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(version=datasets.Version("1.0.0"), **kwargs)
|
||||
|
||||
|
||||
class MMLU(datasets.GeneratorBasedBuilder):
|
||||
BUILDER_CONFIGS = [
|
||||
MMLUConfig(
|
||||
name=task_name,
|
||||
)
|
||||
for task_name in task_list
|
||||
]
|
||||
|
||||
def _info(self):
|
||||
features = datasets.Features(
|
||||
{
|
||||
"question": datasets.Value("string"),
|
||||
"A": datasets.Value("string"),
|
||||
"B": datasets.Value("string"),
|
||||
"C": datasets.Value("string"),
|
||||
"D": datasets.Value("string"),
|
||||
"answer": datasets.Value("string"),
|
||||
}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
features=features,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
citation=_CITATION,
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager):
|
||||
data_dir = dl_manager.download_and_extract(_URL)
|
||||
task_name = self.config.name
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "data", "test", f"{task_name}_test.csv"
|
||||
),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.VALIDATION,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "data", "val", f"{task_name}_val.csv"
|
||||
),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "data", "dev", f"{task_name}_dev.csv"
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepath):
|
||||
df = pd.read_csv(filepath)
|
||||
df.columns = ["question", "A", "B", "C", "D", "answer"]
|
||||
|
||||
for i, instance in enumerate(df.to_dict(orient="records")):
|
||||
yield i, instance
|
||||
@@ -1,9 +1,10 @@
|
||||
torch>=1.13.1
|
||||
transformers>=4.30.0
|
||||
datasets>=2.12.0
|
||||
transformers>=4.31.0,<4.35.0
|
||||
datasets>=2.14.0
|
||||
accelerate>=0.21.0
|
||||
peft==0.4.0
|
||||
trl>=0.7.1
|
||||
peft>=0.6.0
|
||||
trl>=0.7.4
|
||||
gradio>=3.38.0,<4.0.0
|
||||
scipy
|
||||
sentencepiece
|
||||
protobuf
|
||||
@@ -11,9 +12,8 @@ tiktoken
|
||||
jieba
|
||||
rouge-chinese
|
||||
nltk
|
||||
gradio>=3.36.0
|
||||
uvicorn
|
||||
pydantic==1.10.11
|
||||
fastapi==0.95.1
|
||||
pydantic
|
||||
fastapi
|
||||
sse-starlette
|
||||
matplotlib
|
||||
|
||||
4
setup.py
4
setup.py
@@ -25,12 +25,12 @@ def main():
|
||||
version=get_version(),
|
||||
author="hiyouga",
|
||||
author_email="hiyouga" "@" "buaa.edu.cn",
|
||||
description="Easy-to-use fine-tuning framework using PEFT",
|
||||
description="Easy-to-use LLM fine-tuning framework",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
|
||||
license="Apache 2.0 License",
|
||||
url="https://github.com/hiyouga/LLaMA-Efficient-Tuning",
|
||||
url="https://github.com/hiyouga/LLaMA-Factory",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
python_requires=">=3.8.0",
|
||||
|
||||
@@ -6,8 +6,8 @@ from llmtuner import ChatModel, create_app
|
||||
def main():
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
print("Visit http://localhost:8000/docs for API document.")
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import readline
|
||||
from llmtuner import ChatModel
|
||||
|
||||
|
||||
|
||||
10
src/evaluate.py
Normal file
10
src/evaluate.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from llmtuner import Evaluator
|
||||
|
||||
|
||||
def main():
|
||||
evaluator = Evaluator()
|
||||
evaluator.eval()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,9 +1,10 @@
|
||||
# Level: api, webui > chat > tuner > dsets > extras, hparams
|
||||
# Level: api, webui > chat, eval > tuner > dsets > extras, hparams
|
||||
|
||||
from llmtuner.api import create_app
|
||||
from llmtuner.chat import ChatModel
|
||||
from llmtuner.eval import Evaluator
|
||||
from llmtuner.tuner import export_model, run_exp
|
||||
from llmtuner.webui import create_ui, create_web_demo
|
||||
|
||||
|
||||
__version__ = "0.1.8"
|
||||
__version__ = "0.2.2"
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import json
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from sse_starlette import EventSourceResponse
|
||||
from typing import List, Tuple
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.chat import ChatModel
|
||||
@@ -29,6 +31,13 @@ async def lifespan(app: FastAPI): # collects GPU memory
|
||||
torch_gc()
|
||||
|
||||
|
||||
def to_json(data: BaseModel) -> str:
|
||||
try: # pydantic v2
|
||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||
except: # pydantic v1
|
||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
|
||||
def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@@ -45,10 +54,10 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
model_card = ModelCard(id="gpt-3.5-turbo")
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
query = request.messages[-1].content
|
||||
prev_messages = request.messages[:-1]
|
||||
@@ -62,13 +71,20 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
for i in range(0, len(prev_messages), 2):
|
||||
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
|
||||
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
if request.stream:
|
||||
generate = predict(query, history, system, request)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
response, (prompt_length, response_length) = chat_model.chat(
|
||||
query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
|
||||
query, history, system,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
num_return_sequences=request.n
|
||||
)
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
@@ -77,13 +93,13 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
total_tokens=prompt_length+response_length
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role=Role.ASSISTANT, content=response),
|
||||
choices = [ChatCompletionResponseChoice(
|
||||
index=i,
|
||||
message=ChatMessage(role=Role.ASSISTANT, content=choice),
|
||||
finish_reason=Finish.STOP
|
||||
)
|
||||
) for i, choice in enumerate(response)]
|
||||
|
||||
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
|
||||
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
||||
|
||||
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
@@ -92,10 +108,14 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield to_json(chunk)
|
||||
|
||||
for new_text in chat_model.stream_chat(
|
||||
query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
|
||||
query, history, system,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens
|
||||
):
|
||||
if len(new_text) == 0:
|
||||
continue
|
||||
@@ -106,7 +126,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield to_json(chunk)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
@@ -114,7 +134,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
finish_reason=Finish.STOP
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield to_json(chunk)
|
||||
yield "[DONE]"
|
||||
|
||||
return app
|
||||
|
||||
@@ -20,9 +20,6 @@ class ModelCard(BaseModel):
|
||||
object: Optional[str] = "model"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: Optional[str] = "owner"
|
||||
root: Optional[str] = None
|
||||
parent: Optional[str] = None
|
||||
permission: Optional[list] = []
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
@@ -43,6 +40,7 @@ class DeltaMessage(BaseModel):
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
do_sample: Optional[bool] = True
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = 1
|
||||
|
||||
@@ -13,6 +13,7 @@ class ChatModel:
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
self.tokenizer.padding_side = "left"
|
||||
self.model = dispatch_model(self.model)
|
||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.system_prompt = data_args.system_prompt
|
||||
@@ -25,17 +26,17 @@ class ChatModel:
|
||||
**input_kwargs
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
system = system or self.system_prompt
|
||||
|
||||
prompt, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
|
||||
)
|
||||
prompt_length = len(prompt)
|
||||
input_ids = torch.tensor([prompt], device=self.model.device)
|
||||
prompt_length = len(input_ids[0])
|
||||
|
||||
do_sample = input_kwargs.pop("do_sample", None)
|
||||
temperature = input_kwargs.pop("temperature", None)
|
||||
top_p = input_kwargs.pop("top_p", None)
|
||||
top_k = input_kwargs.pop("top_k", None)
|
||||
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
||||
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
||||
max_length = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||
@@ -46,11 +47,15 @@ class ChatModel:
|
||||
temperature=temperature or generating_args["temperature"],
|
||||
top_p=top_p or generating_args["top_p"],
|
||||
top_k=top_k or generating_args["top_k"],
|
||||
num_return_sequences=num_return_sequences or 1,
|
||||
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
pad_token_id=self.tokenizer.pad_token_id
|
||||
))
|
||||
|
||||
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
||||
generating_args["do_sample"] = True
|
||||
|
||||
if max_length:
|
||||
generating_args.pop("max_new_tokens", None)
|
||||
generating_args["max_length"] = max_length
|
||||
@@ -74,12 +79,16 @@ class ChatModel:
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[str, Tuple[int, int]]:
|
||||
) -> Tuple[List[str], Tuple[int, int]]:
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
|
||||
generation_output = self.model.generate(**gen_kwargs)
|
||||
outputs = generation_output.tolist()[0][prompt_length:]
|
||||
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
|
||||
response_length = len(outputs)
|
||||
generate_output = self.model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
response_length = 0
|
||||
for i in range(len(response_ids)):
|
||||
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
|
||||
response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
|
||||
|
||||
return response, (prompt_length, response_length)
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, List, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset
|
||||
|
||||
@@ -26,22 +26,23 @@ def get_dataset(
|
||||
|
||||
if dataset_attr.load_from == "hf_hub":
|
||||
data_path = dataset_attr.dataset_name
|
||||
data_name = dataset_attr.subset
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "script":
|
||||
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||
data_name = dataset_attr.subset
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "file":
|
||||
data_path = None
|
||||
data_path, data_name = None, None
|
||||
data_files: List[str] = []
|
||||
|
||||
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory
|
||||
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory
|
||||
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
|
||||
if data_path is None:
|
||||
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
|
||||
else:
|
||||
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match."
|
||||
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file
|
||||
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
|
||||
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
|
||||
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
|
||||
else:
|
||||
@@ -53,27 +54,75 @@ def get_dataset(
|
||||
raise NotImplementedError
|
||||
|
||||
dataset = load_dataset(
|
||||
data_path,
|
||||
path=data_path,
|
||||
name=data_name,
|
||||
data_files=data_files,
|
||||
split=data_args.split,
|
||||
cache_dir=model_args.cache_dir,
|
||||
streaming=data_args.streaming,
|
||||
use_auth_token=True if model_args.use_auth_token else None
|
||||
token=model_args.hf_hub_token,
|
||||
streaming=data_args.streaming
|
||||
)
|
||||
|
||||
if max_samples is not None:
|
||||
max_samples_temp = min(len(dataset), max_samples)
|
||||
dataset = dataset.select(range(max_samples_temp))
|
||||
if max_samples is not None: # truncate dataset
|
||||
dataset = dataset.select(range(min(len(dataset), max_samples)))
|
||||
|
||||
for column_name in ["prompt", "query", "response", "history"]: # align datasets
|
||||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||
def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
# convert dataset from sharegpt format to alpaca format
|
||||
outputs = {"prompt": [], "query": [], "response": [], "history": []}
|
||||
for msg_list in examples[dataset_attr.messages]:
|
||||
msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
|
||||
if len(msg_list) == 0:
|
||||
continue
|
||||
|
||||
msg_pairs = []
|
||||
user_role, assistant_role = None, None
|
||||
for idx in range(0, len(msg_list), 2):
|
||||
if user_role is None and assistant_role is None:
|
||||
user_role = msg_list[idx][dataset_attr.role]
|
||||
assistant_role = msg_list[idx + 1][dataset_attr.role]
|
||||
else:
|
||||
if (
|
||||
msg_list[idx][dataset_attr.role] != user_role
|
||||
or msg_list[idx+1][dataset_attr.role] != assistant_role
|
||||
):
|
||||
raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
|
||||
msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
|
||||
|
||||
if len(msg_pairs) != 0:
|
||||
outputs["prompt"].append(msg_pairs[-1][0])
|
||||
outputs["query"].append("")
|
||||
outputs["response"].append(msg_pairs[-1][1])
|
||||
outputs["history"].append(msg_pairs[:-1])
|
||||
|
||||
return outputs
|
||||
|
||||
if dataset_attr.formatting == "sharegpt": # convert format
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=(not data_args.overwrite_cache),
|
||||
desc="Converting format of dataset"
|
||||
)
|
||||
|
||||
dataset = dataset.map(
|
||||
convert_format,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
for column_name in ["prompt", "query", "response", "history"]: # align dataset
|
||||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||
|
||||
if dataset_attr.system_prompt: # add system prompt
|
||||
system_prompt = dataset_attr.system_prompt
|
||||
if data_args.streaming:
|
||||
dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
|
||||
dataset = dataset.map(lambda _: {"system": system_prompt})
|
||||
else:
|
||||
dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
|
||||
dataset = dataset.add_column("system", [system_prompt] * len(dataset))
|
||||
|
||||
all_datasets.append(dataset)
|
||||
|
||||
@@ -86,7 +135,11 @@ def get_dataset(
|
||||
elif data_args.mix_strategy.startswith("interleave"):
|
||||
if not data_args.streaming:
|
||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
||||
return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
|
||||
return interleave_datasets(
|
||||
datasets=all_datasets,
|
||||
probabilities=data_args.interleave_probs,
|
||||
seed=data_args.seed,
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown mixing strategy.")
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import os
|
||||
import tiktoken
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
|
||||
|
||||
from datasets import load_from_disk
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -12,6 +16,9 @@ if TYPE_CHECKING:
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess_dataset(
|
||||
dataset: Union["Dataset", "IterableDataset"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
@@ -19,9 +26,11 @@ def preprocess_dataset(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
||||
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||
|
||||
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
||||
for i in range(len(examples["prompt"])):
|
||||
query, response = examples["prompt"][i], examples["response"][i]
|
||||
@@ -30,51 +39,58 @@ def preprocess_dataset(
|
||||
system = examples["system"][i] if "system" in examples else None
|
||||
yield query, response, history, system
|
||||
|
||||
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...`
|
||||
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):
|
||||
kwargs = dict(allowed_special="all") # for tiktoken tokenizer (Qwen)
|
||||
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
||||
kwargs = dict(allowed_special="all")
|
||||
else:
|
||||
kwargs = dict(add_special_tokens=True)
|
||||
|
||||
if hasattr(tokenizer, "add_bos_token") and hasattr(tokenizer, "add_eos_token"):
|
||||
setattr(tokenizer, "add_bos_token", True) # for LLaMA tokenizer
|
||||
if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
|
||||
add_eos_token_flag = getattr(tokenizer, "add_eos_token")
|
||||
setattr(tokenizer, "add_eos_token", True)
|
||||
|
||||
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||
block_size = data_args.max_source_length
|
||||
block_size = data_args.cutoff_len
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of max_source_length
|
||||
# split by chunks of cutoff_len
|
||||
result = {
|
||||
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
# make sure the saved tokenizer is the same as the original one
|
||||
if hasattr(tokenizer, "add_eos_token"):
|
||||
setattr(tokenizer, "add_eos_token", add_eos_token_flag)
|
||||
return result
|
||||
|
||||
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
max_length = data_args.max_source_length + data_args.max_target_length
|
||||
|
||||
for query, response, history, system in construct_example(examples):
|
||||
input_ids, labels = [], []
|
||||
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
|
||||
continue
|
||||
|
||||
input_ids, labels = [], []
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
||||
tokenizer, query, response, history, system
|
||||
)):
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(target_ids) > data_args.max_target_length:
|
||||
target_ids = target_ids[:data_args.max_target_length]
|
||||
total_len = len(source_ids) + len(target_ids)
|
||||
max_source_len = int(data_args.cutoff_len * (len(source_ids) / total_len))
|
||||
max_target_len = int(data_args.cutoff_len * (len(target_ids) / total_len))
|
||||
|
||||
if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
|
||||
break
|
||||
if len(source_ids) > max_source_len:
|
||||
source_ids = source_ids[:max_source_len]
|
||||
if len(target_ids) > max_target_len:
|
||||
target_ids = target_ids[:max_target_len]
|
||||
|
||||
if turn_idx != 0 and template.efficient_eos:
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
@@ -86,65 +102,117 @@ def preprocess_dataset(
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
if len(input_ids) > data_args.cutoff_len:
|
||||
input_ids = input_ids[:data_args.cutoff_len]
|
||||
labels = labels[:data_args.cutoff_len]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||
def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
input_ids, labels = [], []
|
||||
for query, response, history, system in construct_example(examples):
|
||||
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
|
||||
continue
|
||||
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
||||
tokenizer, query, response, history, system
|
||||
)):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
total_length = len(input_ids)
|
||||
block_size = data_args.cutoff_len
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
for i in range(0, total_length, block_size):
|
||||
model_inputs["input_ids"].append(input_ids[i: i + block_size])
|
||||
model_inputs["attention_mask"].append([1] * block_size)
|
||||
model_inputs["labels"].append(labels[i: i + block_size])
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
|
||||
for query, response, history, system in construct_example(examples):
|
||||
source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, system)
|
||||
if not (isinstance(query, str) and query != ""):
|
||||
continue
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(target_ids) > data_args.max_target_length:
|
||||
target_ids = target_ids[:data_args.max_target_length]
|
||||
input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
|
||||
|
||||
if template.efficient_eos:
|
||||
target_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(source_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(source_ids))
|
||||
model_inputs["labels"].append(target_ids)
|
||||
if len(input_ids) > data_args.cutoff_len:
|
||||
input_ids = input_ids[:data_args.cutoff_len]
|
||||
if len(labels) > data_args.cutoff_len:
|
||||
labels = labels[:data_args.cutoff_len]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_pairwise_dataset(examples):
|
||||
def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
||||
for query, response, history, system in construct_example(examples):
|
||||
if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
|
||||
continue
|
||||
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
|
||||
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
|
||||
|
||||
if len(prompt_ids) > data_args.max_source_length:
|
||||
prompt_ids = prompt_ids[:data_args.max_source_length]
|
||||
if len(chosen_ids) > data_args.max_target_length:
|
||||
chosen_ids = chosen_ids[:data_args.max_target_length]
|
||||
if len(rejected_ids) > data_args.max_target_length:
|
||||
rejected_ids = rejected_ids[:data_args.max_target_length]
|
||||
|
||||
if template.efficient_eos:
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
rejected_ids += [tokenizer.eos_token_id]
|
||||
|
||||
total_len = len(prompt_ids) + max(len(chosen_ids), len(rejected_ids))
|
||||
max_source_len = int(data_args.cutoff_len * (len(prompt_ids) / total_len))
|
||||
max_target_len = int(data_args.cutoff_len * (max(len(chosen_ids), len(rejected_ids)) / total_len))
|
||||
|
||||
if len(prompt_ids) > max_source_len:
|
||||
prompt_ids = prompt_ids[:max_source_len]
|
||||
if len(chosen_ids) > max_target_len:
|
||||
chosen_ids = chosen_ids[:max_target_len]
|
||||
if len(rejected_ids) > max_target_len:
|
||||
rejected_ids = rejected_ids[:max_target_len]
|
||||
|
||||
model_inputs["prompt_ids"].append(prompt_ids)
|
||||
model_inputs["chosen_ids"].append(chosen_ids)
|
||||
model_inputs["rejected_ids"].append(rejected_ids)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def print_supervised_dataset_example(example):
|
||||
def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(tokenizer.decode([
|
||||
token_id if token_id != IGNORE_INDEX else tokenizer.pad_token_id for token_id in example["labels"]
|
||||
], skip_special_tokens=False)))
|
||||
print("labels:\n{}".format(
|
||||
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
|
||||
))
|
||||
|
||||
def print_pairwise_dataset_example(example):
|
||||
def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
|
||||
print("prompt_ids:\n{}".format(example["prompt_ids"]))
|
||||
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
|
||||
print("chosen_ids:\n{}".format(example["chosen_ids"]))
|
||||
@@ -152,42 +220,53 @@ def preprocess_dataset(
|
||||
print("rejected_ids:\n{}".format(example["rejected_ids"]))
|
||||
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
|
||||
|
||||
def print_unsupervised_dataset_example(example):
|
||||
def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
||||
if stage == "pt":
|
||||
dataset = dataset.filter(lambda example: example["prompt"])
|
||||
preprocess_function = preprocess_pretrain_dataset
|
||||
preprocess_func = preprocess_pretrain_dataset
|
||||
print_function = print_unsupervised_dataset_example
|
||||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
|
||||
preprocess_function = preprocess_supervised_dataset
|
||||
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
|
||||
print_function = print_supervised_dataset_example
|
||||
elif stage == "rm":
|
||||
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
|
||||
preprocess_function = preprocess_pairwise_dataset
|
||||
preprocess_func = preprocess_pairwise_dataset
|
||||
print_function = print_pairwise_dataset_example
|
||||
else:
|
||||
dataset = dataset.filter(lambda example: example["prompt"])
|
||||
preprocess_function = preprocess_unsupervised_dataset
|
||||
preprocess_func = preprocess_unsupervised_dataset
|
||||
print_function = print_unsupervised_dataset_example
|
||||
|
||||
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
return load_from_disk(data_args.cache_path)
|
||||
|
||||
with training_args.main_process_first(desc="dataset map pre-processing"):
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
load_from_cache_file=(not data_args.overwrite_cache),
|
||||
desc="Running tokenizer on dataset"
|
||||
)
|
||||
|
||||
dataset = dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
preprocess_func,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
print_function(next(iter(dataset)))
|
||||
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
|
||||
if training_args.should_save:
|
||||
dataset.save_to_disk(data_args.cache_path)
|
||||
raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
|
||||
|
||||
if training_args.should_log:
|
||||
try:
|
||||
print_function(next(iter(dataset)))
|
||||
except StopIteration:
|
||||
raise RuntimeError("Empty dataset!")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -13,9 +13,11 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
EXT2TYPE = {
|
||||
"arrow": "arrow",
|
||||
"csv": "csv",
|
||||
"json": "json",
|
||||
"jsonl": "json",
|
||||
"parquet": "parquet",
|
||||
"txt": "text"
|
||||
}
|
||||
|
||||
|
||||
1
src/llmtuner/eval/__init__.py
Normal file
1
src/llmtuner/eval/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from llmtuner.eval.engine import Evaluator
|
||||
3
src/llmtuner/eval/constants.py
Normal file
3
src/llmtuner/eval/constants.py
Normal file
@@ -0,0 +1,3 @@
|
||||
CHOICES = ["A", "B", "C", "D"]
|
||||
|
||||
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
||||
110
src/llmtuner/eval/engine.py
Normal file
110
src/llmtuner/eval/engine.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
||||
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import tiktoken
|
||||
import numpy as np
|
||||
from tqdm import tqdm, trange
|
||||
from datasets import load_dataset
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from llmtuner.eval.constants import CHOICES, SUBJECTS
|
||||
from llmtuner.eval.parser import get_eval_args
|
||||
from llmtuner.eval.template import get_eval_template
|
||||
from llmtuner.extras.misc import dispatch_model
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
|
||||
|
||||
class Evaluator:
|
||||
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
||||
self.model = dispatch_model(self.model)
|
||||
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
|
||||
self.eval_template = get_eval_template(self.eval_args.lang)
|
||||
self.choice_inputs = self._encode_choices()
|
||||
|
||||
def _encode_choices(self) -> List[int]:
|
||||
if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
||||
kwargs = dict(allowed_special="all")
|
||||
else:
|
||||
kwargs = dict(add_special_tokens=False)
|
||||
|
||||
return [self.tokenizer.encode(self.eval_template.prefix + ch, **kwargs)[-1] for ch in CHOICES]
|
||||
|
||||
@torch.inference_mode()
|
||||
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
|
||||
logits = self.model(**batch_input).logits
|
||||
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
|
||||
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
|
||||
choice_probs = torch.nn.functional.softmax(word_probs[:, self.choice_inputs], dim=-1).detach()
|
||||
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
||||
|
||||
def eval(self) -> None:
|
||||
mapping = os.path.join(self.eval_args.task_dir, self.eval_args.task, "mapping.json")
|
||||
with open(mapping, "r", encoding="utf-8") as f:
|
||||
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
||||
|
||||
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
|
||||
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
||||
results = {}
|
||||
for subject in pbar:
|
||||
dataset = load_dataset(
|
||||
path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
||||
name=subject,
|
||||
download_mode="force_redownload"
|
||||
)
|
||||
pbar.set_postfix_str(categorys[subject]["name"])
|
||||
inputs, outputs, labels = [], [], []
|
||||
for i in trange(len(dataset[self.data_args.split]), desc="Formatting batches", position=1, leave=False):
|
||||
support_set = dataset["train"].shuffle().select(range(min(self.eval_args.n_shot, len(dataset["train"]))))
|
||||
query, resp, history = self.eval_template.format_example(
|
||||
target_data=dataset[self.data_args.split][i],
|
||||
support_set=support_set,
|
||||
subject_name=categorys[subject]["name"],
|
||||
use_history=self.template.use_history
|
||||
)
|
||||
input_ids, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, query=query, resp=resp, history=history
|
||||
)
|
||||
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
|
||||
labels.append(resp)
|
||||
|
||||
for i in trange(0, len(inputs), self.eval_args.batch_size, desc="Predicting batches", position=1, leave=False):
|
||||
batch_input = self.tokenizer.pad(
|
||||
inputs[i : i + self.eval_args.batch_size], return_attention_mask=True, return_tensors="pt"
|
||||
).to(self.model.device)
|
||||
preds = self.batch_inference(batch_input)
|
||||
outputs += preds
|
||||
|
||||
corrects = (np.array(outputs) == np.array(labels))
|
||||
category_name = categorys[subject]["category"]
|
||||
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
|
||||
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
|
||||
results[subject] = {str(i): outputs[i] for i in range(len(outputs))}
|
||||
|
||||
pbar.close()
|
||||
self._save_results(category_corrects, results)
|
||||
|
||||
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
|
||||
score_info = "\n".join([
|
||||
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
||||
for category_name, category_correct in category_corrects.items() if len(category_correct)
|
||||
])
|
||||
print(score_info)
|
||||
if self.eval_args.save_dir is not None:
|
||||
os.makedirs(self.eval_args.save_dir, exist_ok=False)
|
||||
with open(os.path.join(self.eval_args.save_dir, "results.json"), "w", encoding="utf-8", newline="\n") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
with open(os.path.join(self.eval_args.save_dir, "results.log"), "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(score_info)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
evaluator = Evaluator()
|
||||
evaluator.eval()
|
||||
49
src/llmtuner/eval/parser.py
Normal file
49
src/llmtuner/eval/parser.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import transformers
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from llmtuner.extras.misc import parse_args
|
||||
from llmtuner.hparams import (
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
EvaluationArguments,
|
||||
FinetuningArguments
|
||||
)
|
||||
|
||||
|
||||
def parse_eval_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
EvaluationArguments,
|
||||
FinetuningArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
EvaluationArguments,
|
||||
FinetuningArguments
|
||||
))
|
||||
return parse_args(parser, args)
|
||||
|
||||
|
||||
def get_eval_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
EvaluationArguments,
|
||||
FinetuningArguments
|
||||
]:
|
||||
model_args, data_args, eval_args, finetuning_args = parse_eval_args(args)
|
||||
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
transformers.set_seed(eval_args.seed)
|
||||
|
||||
return model_args, data_args, eval_args, finetuning_args
|
||||
86
src/llmtuner/eval/template.py
Normal file
86
src/llmtuner/eval/template.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple
|
||||
|
||||
from llmtuner.eval.constants import CHOICES
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalTemplate:
|
||||
|
||||
system: str
|
||||
choice: str
|
||||
answer: str
|
||||
prefix: str
|
||||
|
||||
def parse_example(
|
||||
self,
|
||||
example: Dict[str, str]
|
||||
) -> Tuple[str, str]:
|
||||
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
|
||||
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
|
||||
|
||||
def format_example(
|
||||
self,
|
||||
target_data: Dict[str, str],
|
||||
support_set: "Dataset",
|
||||
subject_name: str,
|
||||
use_history: bool
|
||||
) -> Tuple[str, str, List[Tuple[str, str]]]:
|
||||
query, resp = self.parse_example(target_data)
|
||||
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
|
||||
|
||||
if len(history):
|
||||
temp = history.pop(0)
|
||||
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
|
||||
else:
|
||||
query = self.system.format(subject=subject_name) + query
|
||||
|
||||
if not use_history:
|
||||
query = "\n\n".join(["".join(item) for item in history] + [query])
|
||||
history = []
|
||||
return query.strip(), resp, history
|
||||
|
||||
|
||||
eval_templates: Dict[str, EvalTemplate] = {}
|
||||
|
||||
|
||||
def register_eval_template(
|
||||
name: str,
|
||||
system: str,
|
||||
choice: str,
|
||||
answer: str,
|
||||
prefix: str
|
||||
) -> None:
|
||||
eval_templates[name] = EvalTemplate(
|
||||
system=system,
|
||||
choice=choice,
|
||||
answer=answer,
|
||||
prefix=prefix
|
||||
)
|
||||
|
||||
|
||||
def get_eval_template(name: str) -> EvalTemplate:
|
||||
eval_template = eval_templates.get(name, None)
|
||||
assert eval_template is not None, "Template {} does not exist.".format(name)
|
||||
return eval_template
|
||||
|
||||
|
||||
register_eval_template(
|
||||
name="en",
|
||||
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\nAnswer: ",
|
||||
prefix=" "
|
||||
)
|
||||
|
||||
|
||||
register_eval_template(
|
||||
name="zh",
|
||||
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\n答案:",
|
||||
prefix="\n"
|
||||
)
|
||||
@@ -5,9 +5,7 @@ from typing import TYPE_CHECKING
|
||||
from datetime import timedelta
|
||||
|
||||
from transformers import TrainerCallback
|
||||
from transformers.trainer_callback import TrainerControl, TrainerState
|
||||
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
|
||||
from transformers.training_args import TrainingArguments
|
||||
|
||||
from llmtuner.extras.constants import LOG_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
@@ -27,14 +25,18 @@ class SavePeftModelCallback(TrainerCallback):
|
||||
"""
|
||||
if args.should_save:
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
|
||||
model = kwargs.pop("model")
|
||||
if getattr(model, "is_peft_model", False):
|
||||
getattr(model, "pretrained_model").save_pretrained(output_dir)
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
|
||||
model = kwargs.pop("model")
|
||||
if getattr(model, "is_peft_model", False):
|
||||
getattr(model, "pretrained_model").save_pretrained(args.output_dir)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
@@ -64,7 +66,7 @@ class LogCallback(TrainerCallback):
|
||||
self.in_training = True
|
||||
self.start_time = time.time()
|
||||
self.max_steps = state.max_steps
|
||||
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)):
|
||||
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
|
||||
logger.warning("Previous log file in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
|
||||
|
||||
@@ -132,6 +134,11 @@ class LogCallback(TrainerCallback):
|
||||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time
|
||||
)
|
||||
if self.runner is not None:
|
||||
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
||||
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
|
||||
))
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from collections import defaultdict, OrderedDict
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
LOG_FILE_NAME = "trainer_log.jsonl"
|
||||
|
||||
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
||||
TRAINING_STAGES = {
|
||||
@@ -14,69 +16,222 @@ TRAINING_STAGES = {
|
||||
"Pre-Training": "pt"
|
||||
}
|
||||
|
||||
SUPPORTED_MODELS = {
|
||||
"LLaMA-7B": "huggyllama/llama-7b",
|
||||
"LLaMA-13B": "huggyllama/llama-13b",
|
||||
"LLaMA-30B": "huggyllama/llama-30b",
|
||||
"LLaMA-65B": "huggyllama/llama-65b",
|
||||
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
|
||||
"LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
|
||||
"LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
|
||||
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
|
||||
"ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
|
||||
"ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
|
||||
"ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
|
||||
"ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
|
||||
"BLOOM-560M": "bigscience/bloom-560m",
|
||||
"BLOOM-3B": "bigscience/bloom-3b",
|
||||
"BLOOM-7B1": "bigscience/bloom-7b1",
|
||||
"BLOOMZ-560M": "bigscience/bloomz-560m",
|
||||
"BLOOMZ-3B": "bigscience/bloomz-3b",
|
||||
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
|
||||
"Falcon-7B": "tiiuae/falcon-7b",
|
||||
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
|
||||
"Falcon-40B": "tiiuae/falcon-40b",
|
||||
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
|
||||
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
|
||||
"Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
|
||||
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
|
||||
"Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
|
||||
"Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
|
||||
"Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
|
||||
"Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
|
||||
"InternLM-7B": "internlm/internlm-7b",
|
||||
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
|
||||
"Qwen-7B": "Qwen/Qwen-7B",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
"XVERSE-13B": "xverse/XVERSE-13B",
|
||||
"XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat",
|
||||
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
|
||||
}
|
||||
LAYERNORM_NAMES = {"norm", "ln"}
|
||||
|
||||
DEFAULT_MODULE = {
|
||||
"LLaMA": "q_proj,v_proj",
|
||||
"LLaMA2": "q_proj,v_proj",
|
||||
"ChineseLLaMA2": "q_proj,v_proj",
|
||||
"BLOOM": "query_key_value",
|
||||
"BLOOMZ": "query_key_value",
|
||||
"Falcon": "query_key_value",
|
||||
"Baichuan": "W_pack",
|
||||
"Baichuan2": "W_pack",
|
||||
"InternLM": "q_proj,v_proj",
|
||||
"Qwen": "c_attn",
|
||||
"XVERSE": "q_proj,v_proj",
|
||||
"ChatGLM2": "query_key_value"
|
||||
}
|
||||
SUPPORTED_MODELS = OrderedDict()
|
||||
|
||||
DEFAULT_TEMPLATE = {
|
||||
"LLaMA2": "llama2",
|
||||
"ChineseLLaMA2": "llama2_zh",
|
||||
"Baichuan": "baichuan",
|
||||
"Baichuan2": "baichuan2",
|
||||
"InternLM": "intern",
|
||||
"Qwen": "chatml",
|
||||
"XVERSE": "xverse",
|
||||
"ChatGLM2": "chatglm2"
|
||||
}
|
||||
DEFAULT_MODULE = defaultdict(str)
|
||||
|
||||
DEFAULT_TEMPLATE = defaultdict(str)
|
||||
|
||||
|
||||
def register_model_group(
|
||||
models: Dict[str, str],
|
||||
module: Optional[str] = None,
|
||||
template: Optional[str] = None
|
||||
) -> None:
|
||||
prefix = None
|
||||
for name, path in models.items():
|
||||
if prefix is None:
|
||||
prefix = name.split("-")[0]
|
||||
else:
|
||||
assert prefix == name.split("-")[0], "prefix should be identical."
|
||||
SUPPORTED_MODELS[name] = path
|
||||
if module is not None:
|
||||
DEFAULT_MODULE[prefix] = module
|
||||
if template is not None:
|
||||
DEFAULT_TEMPLATE[prefix] = template
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Baichuan-7B-Base": "baichuan-inc/Baichuan-7B",
|
||||
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
|
||||
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat"
|
||||
},
|
||||
module="W_pack",
|
||||
template="baichuan"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
|
||||
"Baichuan2-13B-Base": "baichuan-inc/Baichuan2-13B-Base",
|
||||
"Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
|
||||
"Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat"
|
||||
},
|
||||
module="W_pack",
|
||||
template="baichuan2"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BLOOM-560M": "bigscience/bloom-560m",
|
||||
"BLOOM-3B": "bigscience/bloom-3b",
|
||||
"BLOOM-7B1": "bigscience/bloom-7b1"
|
||||
},
|
||||
module="query_key_value"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BLOOMZ-560M": "bigscience/bloomz-560m",
|
||||
"BLOOMZ-3B": "bigscience/bloomz-3b",
|
||||
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt"
|
||||
},
|
||||
module="query_key_value"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
|
||||
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat"
|
||||
},
|
||||
template="bluelm"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
|
||||
},
|
||||
module="query_key_value",
|
||||
template="chatglm2"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
|
||||
"ChatGLM3-6B-Chat": "THUDM/chatglm3-6b"
|
||||
},
|
||||
module="query_key_value",
|
||||
template="chatglm3"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
|
||||
"ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
|
||||
"ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
|
||||
"ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b"
|
||||
},
|
||||
template="llama2_zh"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Falcon-7B": "tiiuae/falcon-7b",
|
||||
"Falcon-40B": "tiiuae/falcon-40b",
|
||||
"Falcon-180B": "tiiuae/falcon-180B",
|
||||
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
|
||||
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
|
||||
"Falcon-180B-Chat": "tiiuae/falcon-180B-chat"
|
||||
},
|
||||
module="query_key_value",
|
||||
template="falcon"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"InternLM-7B": "internlm/internlm-7b",
|
||||
"InternLM-20B": "internlm/internlm-20b",
|
||||
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
|
||||
"InternLM-20B-Chat": "internlm/internlm-chat-20b"
|
||||
},
|
||||
template="intern"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LingoWhale-8B": "deeplang-ai/LingoWhale-8B"
|
||||
},
|
||||
module="qkv_proj"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA-7B": "huggyllama/llama-7b",
|
||||
"LLaMA-13B": "huggyllama/llama-13b",
|
||||
"LLaMA-30B": "huggyllama/llama-30b",
|
||||
"LLaMA-65B": "huggyllama/llama-65b"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
|
||||
"LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
|
||||
"LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
|
||||
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf"
|
||||
},
|
||||
template="llama2"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Mistral-7B": "mistralai/Mistral-7B-v0.1",
|
||||
"Mistral-7B-Chat": "mistralai/Mistral-7B-Instruct-v0.1"
|
||||
},
|
||||
template="mistral"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Phi1.5-1.3B": "microsoft/phi-1_5"
|
||||
},
|
||||
module="Wqkv"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen-7B": "Qwen/Qwen-7B",
|
||||
"Qwen-14B": "Qwen/Qwen-14B",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat"
|
||||
},
|
||||
module="c_attn",
|
||||
template="qwen"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Skywork-13B-Base": "Skywork/Skywork-13B-base"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"XVERSE-7B": "xverse/XVERSE-7B",
|
||||
"XVERSE-13B": "xverse/XVERSE-13B",
|
||||
"XVERSE-65B": "xverse/XVERSE-65B",
|
||||
"XVERSE-7B-Chat": "xverse/XVERSE-7B-Chat",
|
||||
"XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat"
|
||||
},
|
||||
template="xverse"
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Yi-6B": "01-ai/Yi-6B",
|
||||
"Yi-34B": "01-ai/Yi-34B"
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,9 +1,25 @@
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
||||
|
||||
try:
|
||||
from transformers.utils import (
|
||||
is_torch_bf16_cpu_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_cuda_available,
|
||||
is_torch_npu_available
|
||||
)
|
||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
|
||||
except ImportError:
|
||||
_is_fp16_available = torch.cuda.is_available()
|
||||
_is_bf16_available = torch.cuda.is_bf16_supported()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import HfArgumentParser
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
@@ -49,7 +65,22 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
return trainable_params, all_param
|
||||
|
||||
|
||||
def get_logits_processor() -> LogitsProcessorList:
|
||||
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
||||
r"""
|
||||
Infers the optimal dtype according to the model_dtype and device compatibility.
|
||||
"""
|
||||
if _is_bf16_available and model_dtype == torch.bfloat16:
|
||||
return torch.bfloat16
|
||||
elif _is_fp16_available:
|
||||
return torch.float16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
|
||||
def get_logits_processor() -> "LogitsProcessorList":
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
"""
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
return logits_processor
|
||||
@@ -65,6 +96,17 @@ def torch_gc() -> None:
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
return parser.parse_args_into_dataclasses()
|
||||
|
||||
|
||||
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
r"""
|
||||
Dispatches a pre-trained model to GPUs with balanced memory.
|
||||
|
||||
@@ -1,305 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Modified from:
|
||||
# [1] https://huggingface.co/Birchlabs/flash_llama/blob/main/modeling_flash_llama.py
|
||||
# [2] https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama2_flash_attn_monkey_patch.py
|
||||
# [3] https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py
|
||||
# [4] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
# With fix from Alex Birch: https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/17
|
||||
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
from transformers.utils import logging
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_kvpacked_func,
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
)
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
flash_attn_v2_installed = True
|
||||
print('>>>> Flash Attention installed')
|
||||
except ImportError:
|
||||
flash_attn_v2_installed = False
|
||||
raise ImportError('Please install Flash Attention: `pip install flash-attn --no-build-isolation`')
|
||||
|
||||
try:
|
||||
from flash_attn.layers.rotary import apply_rotary_emb_func
|
||||
flash_rope_installed = True
|
||||
print('>>>> Flash RoPE installed')
|
||||
except ImportError:
|
||||
flash_rope_installed = False
|
||||
raise ImportError('Please install RoPE kernels: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`')
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LlamaRMSNorm(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return (self.weight * hidden_states).to(input_dtype) # for fp32 weight
|
||||
|
||||
|
||||
class FlashRotaryEmbedding(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base=10000.0,
|
||||
interleaved=False,
|
||||
scale_base=None,
|
||||
scaling_factor=1.0,
|
||||
pos_idx_in_fp32=True,
|
||||
device=None
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.base = float(base)
|
||||
self.pos_idx_in_fp32 = pos_idx_in_fp32
|
||||
# Generate and save the inverse frequency buffer (non trainable)
|
||||
inv_freq = self._compute_inv_freq(device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.interleaved = interleaved
|
||||
self.scale_base = scale_base
|
||||
self.scaling_factor = scaling_factor
|
||||
scale = (
|
||||
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
||||
if scale_base is not None else None
|
||||
)
|
||||
self.register_buffer("scale", scale)
|
||||
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self._cos_k_cached = None
|
||||
self._sin_k_cached = None
|
||||
|
||||
def _compute_inv_freq(self, device=None):
|
||||
return 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
|
||||
|
||||
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
||||
if (
|
||||
seqlen > self._seq_len_cached or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
or (self.training and self._cos_cached.is_inference())
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
if self.pos_idx_in_fp32:
|
||||
t = torch.arange(seqlen, device=device, dtype=torch.float32)
|
||||
t /= self.scaling_factor
|
||||
if self.inv_freq.dtype != torch.float32:
|
||||
inv_freq = self.inv_freq.to(torch.float32)
|
||||
else:
|
||||
inv_freq = self.inv_freq
|
||||
else:
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
t /= self.scaling_factor
|
||||
inv_freq = self.inv_freq
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
if self.scale is None:
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
else:
|
||||
power = (
|
||||
(torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2) / self.scale_base
|
||||
)
|
||||
scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
|
||||
# We want the multiplication by scale to happen in fp32
|
||||
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
||||
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
||||
self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
|
||||
self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
q: (batch, seqlen, nheads, headdim)
|
||||
k: (batch, seqlen, nheads, headdim)
|
||||
seqlen_offset: can be used in generation where the qkv being passed in is only the last
|
||||
token in the batch.
|
||||
"""
|
||||
self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype)
|
||||
if self.scale is None:
|
||||
return apply_rotary_emb_func(
|
||||
q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
|
||||
self.interleaved, True # inplace=True
|
||||
), apply_rotary_emb_func(
|
||||
k, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
|
||||
self.interleaved, True # inplace=True
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
r"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, slen, _, num_key_value_heads, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, :, :, None, :].expand(batch, slen, 2, num_key_value_heads, n_rep, head_dim)
|
||||
return hidden_states.reshape(batch, slen, 2, num_key_value_heads * n_rep, head_dim)
|
||||
|
||||
|
||||
class LlamaAttention(torch.nn.Module):
|
||||
|
||||
def __init__(self, config: "LlamaConfig"):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
self.q_proj = torch.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.k_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.v_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
self.register_buffer(
|
||||
"norm_factor",
|
||||
torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
if self.config.rope_scaling is None:
|
||||
scaling_factor = 1
|
||||
else:
|
||||
scaling_type = self.config.rope_scaling["type"]
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
assert scaling_type == "linear"
|
||||
|
||||
self.rotary_emb = FlashRotaryEmbedding(
|
||||
self.head_dim, base=10000, interleaved=False, scaling_factor=scaling_factor
|
||||
)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, h_size = hidden_states.size()
|
||||
|
||||
has_layer_past = past_key_value is not None
|
||||
|
||||
if has_layer_past:
|
||||
past_kv = past_key_value[0]
|
||||
past_len = past_key_value[1]
|
||||
else:
|
||||
past_len = 0
|
||||
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
|
||||
q = q.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
||||
v = v.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
|
||||
|
||||
q, k = self.rotary_emb(q, k, past_len)
|
||||
|
||||
kv = torch.stack([k, v], 2)
|
||||
kv = repeat_kv(kv, self.num_key_value_groups)
|
||||
|
||||
# Cache QKV values
|
||||
if has_layer_past:
|
||||
new_len = past_len+q.size(1)
|
||||
if new_len > past_kv.size(1):
|
||||
past_kv = torch.cat(
|
||||
[past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1
|
||||
)
|
||||
past_kv[:, past_len:new_len] = kv
|
||||
kv = past_kv[:, :new_len]
|
||||
else:
|
||||
past_kv = kv
|
||||
|
||||
past_key_value = (past_kv, past_len + q.size(1)) if use_cache else None
|
||||
|
||||
if attention_mask is not None:
|
||||
# varlen, ignore padding tokens, efficient for large batch with many paddings
|
||||
logger.warning_once("padded sequences is less efficient")
|
||||
|
||||
unpadded_kv, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(kv, attention_mask)
|
||||
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask[:, -q.size(1):])
|
||||
attn_outputs = flash_attn_varlen_kvpacked_func(
|
||||
unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k,
|
||||
dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
|
||||
causal=(not has_layer_past), return_attn_probs=output_attentions
|
||||
)
|
||||
|
||||
attn_output = attn_outputs[0] if output_attentions else attn_outputs
|
||||
attn_output = pad_input(
|
||||
attn_output, indices_q, bsz, q_len
|
||||
).reshape(bsz, q_len, h_size)
|
||||
attn_weights = attn_outputs[2] if output_attentions else None
|
||||
|
||||
else:
|
||||
# no padding tokens, more efficient
|
||||
attn_outputs = flash_attn_kvpacked_func(
|
||||
q, kv, dropout_p=0.0, softmax_scale=1.0/self.norm_factor,
|
||||
causal=(not has_layer_past), return_attn_probs=output_attentions
|
||||
)
|
||||
attn_output = attn_outputs[0] if output_attentions else attn_outputs
|
||||
attn_output = attn_output.reshape(bsz, q_len, h_size)
|
||||
attn_weights = attn_outputs[2] if output_attentions else None
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as flash attention
|
||||
# takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
|
||||
def _prepare_decoder_attention_mask(
|
||||
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
||||
):
|
||||
# [bsz, seq_len]
|
||||
if past_key_values_length > 0 and attention_mask is not None:
|
||||
attention_mask = torch.cat(
|
||||
(
|
||||
torch.full(
|
||||
(input_shape[0], past_key_values_length),
|
||||
True,
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device
|
||||
),
|
||||
attention_mask
|
||||
),
|
||||
dim=-1
|
||||
)
|
||||
|
||||
if attention_mask is not None and torch.all(attention_mask):
|
||||
return None # This uses the faster call when training with full samples
|
||||
|
||||
return attention_mask
|
||||
221
src/llmtuner/extras/patches/llama_patch.py
Normal file
221
src/llmtuner/extras/patches/llama_patch.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Tuple
|
||||
from transformers.utils import logging
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
is_flash_attn_2_available = False
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
|
||||
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
|
||||
is_flash_attn_2_available = True
|
||||
except ImportError:
|
||||
is_flash_attn_2_available = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
class LlamaShiftShortAttention(LlamaAttention):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None: # reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
if getattr(self, "num_key_value_groups"):
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
num_groups = q_len // groupsz
|
||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
|
||||
state = torch.cat((
|
||||
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
|
||||
), dim=2)
|
||||
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
||||
attn_output = torch.cat((
|
||||
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
|
||||
))
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class LlamaFlashAttention2(LlamaAttention):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# LlamaFlashAttention2 attention does not support output_attentions
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None: # reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# cast to half precision
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
logger.warning_once("The input hidden states seems to be silently casted in float32.")
|
||||
query_states = query_states.to(self.config.torch_dtype)
|
||||
key_states = key_states.to(self.config.torch_dtype)
|
||||
value_states = value_states.to(self.config.torch_dtype)
|
||||
|
||||
if getattr(self, "num_key_value_groups", None):
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
num_groups = q_len // groupsz
|
||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||
state = torch.cat((
|
||||
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
|
||||
), dim=2)
|
||||
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
|
||||
|
||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
|
||||
|
||||
if attention_mask is not None:
|
||||
logger.warning_once("Padded sequences are less efficient in FlashAttention.")
|
||||
# -q_len: assumes left padding when q_len != kv_len
|
||||
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
|
||||
unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
|
||||
unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
|
||||
attn_output_unpad = flash_attn_varlen_func(
|
||||
unpadded_q,
|
||||
unpadded_k,
|
||||
unpadded_v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
|
||||
else:
|
||||
attn_output = flash_attn_func(
|
||||
query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
||||
attn_output = torch.cat((
|
||||
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
|
||||
))
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as flash attention
|
||||
# takes a boolean padding_mask. Fills in the past kv length for use in forward.
|
||||
def _prepare_decoder_attention_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_shape: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
past_key_values_length: int
|
||||
) -> torch.Tensor:
|
||||
if attention_mask is not None and torch.all(attention_mask):
|
||||
return None # This uses the faster call when training with full samples
|
||||
|
||||
return attention_mask
|
||||
@@ -1,21 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
from transformers.trainer import WEIGHTS_NAME
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
|
||||
vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
|
||||
if not os.path.exists(vhead_file):
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
|
||||
return False
|
||||
vhead_params = torch.load(vhead_file, map_location="cpu")
|
||||
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
|
||||
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
|
||||
return True
|
||||
@@ -138,16 +138,15 @@ class Template:
|
||||
token_ids = []
|
||||
for elem in context:
|
||||
if isinstance(elem, str):
|
||||
if len(elem) == 0:
|
||||
continue
|
||||
elem = elem.replace("{{system}}", system, 1) if system is not None else elem
|
||||
elem = elem.replace("{{query}}", query, 1) if query is not None else elem
|
||||
elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
|
||||
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
|
||||
if len(elem) != 0:
|
||||
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
|
||||
elif isinstance(elem, dict):
|
||||
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))
|
||||
|
||||
return token_ids
|
||||
|
||||
@@ -226,89 +225,8 @@ def get_template_and_fix_tokenizer(
|
||||
return template
|
||||
|
||||
|
||||
r"""
|
||||
Supports language model inference without histories.
|
||||
"""
|
||||
register_template(
|
||||
name="vanilla",
|
||||
prefix=[],
|
||||
prompt=[
|
||||
"{{query}}"
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
use_history=False
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Default template.
|
||||
"""
|
||||
register_template(
|
||||
name="default",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\nAssistant: "
|
||||
],
|
||||
system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
sep=[
|
||||
"\n"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
|
||||
https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
|
||||
https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
|
||||
"""
|
||||
register_template(
|
||||
name="llama2",
|
||||
prefix=[
|
||||
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST] "
|
||||
],
|
||||
system=(
|
||||
"You are a helpful, respectful and honest assistant. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
|
||||
https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
|
||||
"""
|
||||
register_template(
|
||||
name="llama2_zh",
|
||||
prefix=[
|
||||
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST] "
|
||||
],
|
||||
system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
|
||||
https://github.com/ymcui/Chinese-LLaMA-Alpaca
|
||||
"""
|
||||
register_template(
|
||||
name="alpaca",
|
||||
@@ -329,102 +247,9 @@ register_template(
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
|
||||
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
|
||||
"""
|
||||
register_template(
|
||||
name="vicuna",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"USER: {{query}} ASSISTANT: "
|
||||
],
|
||||
system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
|
||||
"""
|
||||
register_template(
|
||||
name="belle",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\n\nBelle: "
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n\n"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://github.com/CVI-SZU/Linly
|
||||
"""
|
||||
register_template(
|
||||
name="linly",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"User: {{query}}\nBot: "
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://github.com/Neutralzz/BiLLa
|
||||
"""
|
||||
register_template(
|
||||
name="billa",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\nAssistant: "
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
|
||||
"""
|
||||
register_template(
|
||||
name="ziya",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<human>"},
|
||||
":{{query}}\n",
|
||||
{"token": "<bot>"},
|
||||
":"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/qhduan/aquilachat-7b
|
||||
Supports: https://huggingface.co/BAAI/AquilaChat-7B
|
||||
https://huggingface.co/BAAI/AquilaChat2-7B
|
||||
https://huggingface.co/BAAI/AquilaChat2-34B
|
||||
"""
|
||||
register_template(
|
||||
name="aquila",
|
||||
@@ -432,7 +257,7 @@ register_template(
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}###Assistant: "
|
||||
"Human: {{query}}###Assistant:"
|
||||
],
|
||||
system=(
|
||||
"A chat between a curious human and an artificial intelligence assistant. "
|
||||
@@ -440,30 +265,9 @@ register_template(
|
||||
),
|
||||
sep=[
|
||||
"###"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/internlm/internlm-chat-7b
|
||||
"""
|
||||
register_template(
|
||||
name="intern",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"<|User|>:{{query}}",
|
||||
{"token": "<eoh>"},
|
||||
"\n<|Bot|>:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
{"token": "<eoa>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<eoa>"
|
||||
"</s>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
@@ -508,6 +312,300 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
|
||||
"""
|
||||
register_template(
|
||||
name="belle",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\n\nBelle: "
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n\n"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/vivo-ai/BlueLM-7B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="bluelm",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "[|Human|]:"},
|
||||
"{{query}}",
|
||||
{"token": "[|AI|]:"}
|
||||
],
|
||||
system="",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/THUDM/chatglm2-6b
|
||||
"""
|
||||
register_template(
|
||||
name="chatglm2",
|
||||
prefix=[
|
||||
{"token": "[gMASK]"},
|
||||
{"token": "sop"},
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n\n"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/THUDM/chatglm3-6b
|
||||
"""
|
||||
register_template(
|
||||
name="chatglm3",
|
||||
prefix=[
|
||||
{"token": "[gMASK]"},
|
||||
{"token": "sop"},
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|user|>"},
|
||||
"\n",
|
||||
"{{query}}",
|
||||
{"token": "<|assistant|>"}
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
stop_words=[
|
||||
"<|user|>",
|
||||
"<|observation|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-instruct
|
||||
https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct
|
||||
https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct
|
||||
"""
|
||||
register_template(
|
||||
name="deepseek",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"### Instruction:\n{{query}}\n\n### Response:\n"
|
||||
],
|
||||
system=(
|
||||
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
||||
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
||||
"For politically sensitive questions, security and privacy issues, "
|
||||
"and other non-computer science questions, you will refuse to answer."
|
||||
),
|
||||
sep=[
|
||||
"\n",
|
||||
{"token": "<|EOT|>"},
|
||||
"\n\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|EOT|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Default template.
|
||||
"""
|
||||
register_template(
|
||||
name="default",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\nAssistant:"
|
||||
],
|
||||
system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
sep=[
|
||||
"\n"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/tiiuae/falcon-180B-chat
|
||||
"""
|
||||
register_template(
|
||||
name="falcon",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"User: {{query}}\nFalcon:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/internlm/internlm-chat-7b
|
||||
https://huggingface.co/internlm/internlm-chat-20b
|
||||
"""
|
||||
register_template(
|
||||
name="intern",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"<|User|>:{{query}}",
|
||||
{"token": "<eoh>"},
|
||||
"\n<|Bot|>:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
{"token": "<eoa>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<eoa>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
|
||||
https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
|
||||
https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
|
||||
"""
|
||||
register_template(
|
||||
name="llama2",
|
||||
prefix=[
|
||||
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST]"
|
||||
],
|
||||
system=(
|
||||
"You are a helpful, respectful and honest assistant. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
|
||||
https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
|
||||
"""
|
||||
register_template(
|
||||
name="llama2_zh",
|
||||
prefix=[
|
||||
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST]"
|
||||
],
|
||||
system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
|
||||
"""
|
||||
register_template(
|
||||
name="mistral",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST]"
|
||||
],
|
||||
system="",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/openchat/openchat_3.5
|
||||
"""
|
||||
register_template(
|
||||
name="openchat",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"GPT4 Correct User: {{query}}",
|
||||
{"token": "<|end_of_turn|>"},
|
||||
"GPT4 Correct Assistant:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
{"token": "<|end_of_turn|>"}
|
||||
],
|
||||
stop_words=[
|
||||
"<|end_of_turn|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
|
||||
https://huggingface.co/Qwen/Qwen-14B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="qwen",
|
||||
prefix=[
|
||||
{"token": "<|im_start|>"},
|
||||
"system\n{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|im_start|>"},
|
||||
"user\n{{query}}",
|
||||
{"token": "<|im_end|>"},
|
||||
"\n",
|
||||
{"token": "<|im_start|>"},
|
||||
"assistant\n"
|
||||
],
|
||||
system="You are a helpful assistant.",
|
||||
sep=[
|
||||
{"token": "<|im_end|>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|im_end|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
|
||||
https://huggingface.co/HuggingFaceH4/starchat-beta
|
||||
@@ -538,57 +636,43 @@ register_template(
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
|
||||
Supports language model inference without histories.
|
||||
"""
|
||||
register_template(
|
||||
name="chatml",
|
||||
prefix=[
|
||||
{"token": "<|im_start|>"},
|
||||
"system\n{{system}}"
|
||||
],
|
||||
name="vanilla",
|
||||
prefix=[],
|
||||
prompt=[
|
||||
{"token": "<|im_start|>"},
|
||||
"user\n{{query}}",
|
||||
{"token": "<|im_end|>"},
|
||||
"\n",
|
||||
{"token": "<|im_start|>"},
|
||||
"assistant\n"
|
||||
"{{query}}"
|
||||
],
|
||||
system="You are a helpful assistant.",
|
||||
sep=[
|
||||
{"token": "<|im_end|>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|im_end|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
system="",
|
||||
sep=[],
|
||||
use_history=False
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/THUDM/chatglm2-6b
|
||||
Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
|
||||
https://huggingface.co/lmsys/vicuna-13b-v1.5
|
||||
"""
|
||||
register_template(
|
||||
name="chatglm2",
|
||||
name="vicuna",
|
||||
prefix=[
|
||||
{"token": "[gMASK]"},
|
||||
{"token": "sop"},
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
|
||||
"USER: {{query}} ASSISTANT:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n\n"
|
||||
],
|
||||
efficient_eos=True
|
||||
system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/xverse/XVERSE-13B-Chat
|
||||
Supports: https://huggingface.co/xverse/XVERSE-7B-Chat
|
||||
https://huggingface.co/xverse/XVERSE-13B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="xverse",
|
||||
@@ -601,3 +685,85 @@ register_template(
|
||||
system="",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/wenge-research/yayi-7b
|
||||
https://huggingface.co/wenge-research/yayi-7b-llama2
|
||||
https://huggingface.co/wenge-research/yayi-13b-llama2
|
||||
"""
|
||||
register_template(
|
||||
name="yayi",
|
||||
prefix=[
|
||||
{"token": "<|System|>"},
|
||||
":\n{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|Human|>"},
|
||||
":\n{{query}}\n\n",
|
||||
{"token": "<|YaYi|>"},
|
||||
":"
|
||||
],
|
||||
system=(
|
||||
"You are a helpful, respectful and honest assistant named YaYi "
|
||||
"developed by Beijing Wenge Technology Co.,Ltd. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
sep=[
|
||||
"\n\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|End|>"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha
|
||||
https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
|
||||
"""
|
||||
register_template(
|
||||
name="zephyr",
|
||||
prefix=[
|
||||
{"token": "<|system|>"},
|
||||
"\n{{system}}",
|
||||
{"token": "</s>"}
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|user|>"},
|
||||
"\n{{query}}",
|
||||
{"token": "</s>"},
|
||||
{"token": "<|assistant|>"}
|
||||
],
|
||||
system="You are a friendly chatbot who always responds in the style of a pirate",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
|
||||
https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1
|
||||
https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="ziya",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<human>"},
|
||||
":{{query}}\n",
|
||||
{"token": "<bot>"},
|
||||
":"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from .data_args import DataArguments
|
||||
from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
from .general_args import GeneralArguments
|
||||
from .generating_args import GeneratingArguments
|
||||
from .model_args import ModelArguments
|
||||
|
||||
@@ -11,11 +11,17 @@ class DatasetAttr:
|
||||
dataset_name: Optional[str] = None
|
||||
dataset_sha1: Optional[str] = None
|
||||
system_prompt: Optional[str] = None
|
||||
subset: Optional[str] = None
|
||||
ranking: Optional[bool] = False
|
||||
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
|
||||
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
history: Optional[str] = None
|
||||
messages: Optional[str] = "conversations"
|
||||
role: Optional[str] = "from"
|
||||
content: Optional[str] = "value"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
@@ -31,28 +37,36 @@ class DataArguments:
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||
)
|
||||
dataset: Optional[str] = field(
|
||||
default="alpaca_en",
|
||||
default=None,
|
||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
|
||||
)
|
||||
dataset_dir: Optional[str] = field(
|
||||
default="data",
|
||||
metadata={"help": "The name of the folder containing datasets."}
|
||||
metadata={"help": "Path to the folder containing the datasets."}
|
||||
)
|
||||
split: Optional[str] = field(
|
||||
default="train",
|
||||
metadata={"help": "Which dataset split to use for training and evaluation."}
|
||||
)
|
||||
cutoff_len: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={"help": "The maximum length of the model inputs after tokenization."}
|
||||
)
|
||||
train_on_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to disable the mask on the prompt or not."}
|
||||
)
|
||||
streaming: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable streaming mode."}
|
||||
metadata={"help": "Enable dataset streaming."}
|
||||
)
|
||||
buffer_size: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
|
||||
default=16384,
|
||||
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
|
||||
)
|
||||
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
|
||||
default="concat",
|
||||
metadata={"help": "Strategy to use in dataset mixing."}
|
||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
|
||||
)
|
||||
interleave_probs: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -66,14 +80,6 @@ class DataArguments:
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."}
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={"help": "The maximum total input sequence length after tokenization."}
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={"help": "The maximum total output sequence length after tokenization."}
|
||||
)
|
||||
max_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
|
||||
@@ -94,11 +100,35 @@ class DataArguments:
|
||||
default=0,
|
||||
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
|
||||
)
|
||||
sft_packing: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
|
||||
)
|
||||
cache_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save or load the preprocessed datasets."}
|
||||
)
|
||||
|
||||
def init_for_training(self): # support mixing multiple datasets
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
def __post_init__(self):
|
||||
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
|
||||
raise ValueError("Streaming mode should have an integer val size.")
|
||||
|
||||
if self.streaming and self.max_samples is not None:
|
||||
raise ValueError("`max_samples` is incompatible with `streaming`.")
|
||||
|
||||
if self.streaming and self.cache_path:
|
||||
raise ValueError("`cache_path` is incompatible with `streaming`.")
|
||||
|
||||
def init_for_training(self, seed: int): # support mixing multiple datasets
|
||||
self.seed = seed
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
|
||||
try:
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception:
|
||||
if self.dataset is not None:
|
||||
raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
|
||||
dataset_info = None
|
||||
|
||||
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
|
||||
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
|
||||
@@ -128,7 +158,12 @@ class DataArguments:
|
||||
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
|
||||
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
|
||||
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
|
||||
dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
|
||||
dataset_attr.role = dataset_info[name]["columns"].get("role", None)
|
||||
dataset_attr.content = dataset_info[name]["columns"].get("content", None)
|
||||
|
||||
dataset_attr.subset = dataset_info[name].get("subset", None)
|
||||
dataset_attr.ranking = dataset_info[name].get("ranking", False)
|
||||
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
|
||||
dataset_attr.system_prompt = prompt_list[i]
|
||||
self.dataset_list.append(dataset_attr)
|
||||
|
||||
55
src/llmtuner/hparams/evaluation_args.py
Normal file
55
src/llmtuner/hparams/evaluation_args.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import os
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from datasets import DownloadMode
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationArguments:
|
||||
r"""
|
||||
Arguments pertaining to specify the evaluation parameters.
|
||||
"""
|
||||
task: str = field(
|
||||
metadata={"help": "Name of the evaluation task."}
|
||||
)
|
||||
task_dir: Optional[str] = field(
|
||||
default="evaluation",
|
||||
metadata={"help": "Path to the folder containing the evaluation datasets."}
|
||||
)
|
||||
batch_size: Optional[int] = field(
|
||||
default=4,
|
||||
metadata={"help": "The batch size per GPU for evaluation."}
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=42,
|
||||
metadata={"help": "Random seed to be used with data loaders."}
|
||||
)
|
||||
lang: Optional[Literal["en", "zh"]] = field(
|
||||
default="en",
|
||||
metadata={"help": "Language used at evaluation."}
|
||||
)
|
||||
n_shot: Optional[int] = field(
|
||||
default=5,
|
||||
metadata={"help": "Number of examplars for few-shot learning."}
|
||||
)
|
||||
save_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save the evaluation results."}
|
||||
)
|
||||
download_mode: Optional[DownloadMode] = field(
|
||||
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
|
||||
metadata={"help": "Download mode used for the evaluation datasets."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
task_available = []
|
||||
for folder in os.listdir(self.task_dir):
|
||||
if os.path.isdir(os.path.join(self.task_dir, folder)):
|
||||
task_available.append(folder)
|
||||
|
||||
if self.task not in task_available:
|
||||
raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
|
||||
|
||||
if self.save_dir is not None and os.path.exists(self.save_dir):
|
||||
raise ValueError("`save_dir` already exists, use another one.")
|
||||
@@ -8,22 +8,14 @@ class FinetuningArguments:
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."}
|
||||
)
|
||||
finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
|
||||
default="lora",
|
||||
metadata={"help": "Which fine-tuning method to use."}
|
||||
)
|
||||
num_hidden_layers: Optional[int] = field(
|
||||
default=32,
|
||||
metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
|
||||
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
|
||||
LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
|
||||
BLOOM choices: [\"24\", \"30\", \"70\"], \
|
||||
Falcon choices: [\"32\", \"60\"], \
|
||||
Baichuan choices: [\"32\", \"40\"] \
|
||||
Qwen choices: [\"32\"], \
|
||||
XVERSE choices: [\"40\"], \
|
||||
ChatGLM2 choices: [\"28\"]"}
|
||||
)
|
||||
num_layer_trainable: Optional[int] = field(
|
||||
default=3,
|
||||
metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
|
||||
@@ -32,10 +24,10 @@ class FinetuningArguments:
|
||||
default="mlp",
|
||||
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
||||
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
||||
BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \
|
||||
Baichuan choices: [\"mlp\", \"self_attn\"], \
|
||||
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
|
||||
Qwen choices: [\"mlp\", \"attn\"], \
|
||||
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
|
||||
Phi-1.5 choices: [\"mlp\", \"mixer\"], \
|
||||
LLaMA-2, BlueLM, Baichuan, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
|
||||
)
|
||||
lora_rank: Optional[int] = field(
|
||||
default=8,
|
||||
@@ -53,10 +45,15 @@ class FinetuningArguments:
|
||||
default=None,
|
||||
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
|
||||
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
|
||||
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
|
||||
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
|
||||
LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
|
||||
Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
|
||||
LLaMA-2, BlueLM, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
|
||||
)
|
||||
additional_target: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
|
||||
)
|
||||
resume_lora_training: Optional[bool] = field(
|
||||
default=True,
|
||||
@@ -64,25 +61,45 @@ class FinetuningArguments:
|
||||
)
|
||||
ppo_score_norm: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Use score normalization in PPO Training."}
|
||||
metadata={"help": "Use score normalization in PPO training."}
|
||||
)
|
||||
ppo_logger: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
|
||||
)
|
||||
ppo_target: Optional[float] = field(
|
||||
default=6.0,
|
||||
metadata={"help": "Target KL value for adaptive KL control in PPO training."}
|
||||
)
|
||||
dpo_beta: Optional[float] = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the DPO loss."}
|
||||
)
|
||||
dpo_ref_model: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the reference model used for the DPO training."}
|
||||
)
|
||||
dpo_ref_model_checkpoint: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
|
||||
)
|
||||
upcast_layernorm: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
|
||||
)
|
||||
neft_alpha: Optional[float] = field(
|
||||
default=0,
|
||||
metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
|
||||
self.lora_target = [target.strip() for target in self.lora_target.split(",")]
|
||||
|
||||
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||
trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
|
||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
|
||||
if isinstance(self.additional_target, str):
|
||||
self.additional_target = [target.strip() for target in self.additional_target.split(",")]
|
||||
|
||||
self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
|
||||
|
||||
assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
|
||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralArguments:
|
||||
r"""
|
||||
Arguments pertaining to which stage we are going to perform.
|
||||
"""
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."}
|
||||
)
|
||||
@@ -28,7 +28,7 @@ class GeneratingArguments:
|
||||
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
||||
)
|
||||
max_length: Optional[int] = field(
|
||||
default=None,
|
||||
default=512,
|
||||
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
|
||||
)
|
||||
max_new_tokens: Optional[int] = field(
|
||||
@@ -46,6 +46,8 @@ class GeneratingArguments:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
args = asdict(self)
|
||||
if args.get("max_new_tokens", None):
|
||||
if args.get("max_new_tokens", -1) > 0:
|
||||
args.pop("max_length", None)
|
||||
else:
|
||||
args.pop("max_new_tokens", None)
|
||||
return args
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import torch
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -19,9 +18,9 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
|
||||
)
|
||||
use_auth_token: Optional[bool] = field(
|
||||
split_special_tokens: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
|
||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
|
||||
)
|
||||
model_revision: Optional[str] = field(
|
||||
default="main",
|
||||
@@ -43,38 +42,41 @@ class ModelArguments:
|
||||
default=None,
|
||||
metadata={"help": "Adopt scaled rotary positional embeddings."}
|
||||
)
|
||||
flash_attn: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable flash attention for faster training."}
|
||||
)
|
||||
checkpoint_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
|
||||
metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
|
||||
)
|
||||
reward_model: Optional[str] = field(
|
||||
flash_attn: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable FlashAttention-2 for faster training."}
|
||||
)
|
||||
shift_attn: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
|
||||
)
|
||||
reward_model: Optional[str] = field( # TODO: move it to FinetuningArguments
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
||||
)
|
||||
plot_loss: Optional[bool] = field(
|
||||
plot_loss: Optional[bool] = field( # TODO: move it to FinetuningArguments
|
||||
default=False,
|
||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||
)
|
||||
hf_auth_token: Optional[str] = field(
|
||||
hf_hub_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."}
|
||||
)
|
||||
compute_dtype: Optional[torch.dtype] = field(
|
||||
export_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
|
||||
)
|
||||
model_max_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "Used in rope scaling. Do not specify this argument manually."}
|
||||
metadata={"help": "Path to the directory to save the exported model."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.compute_dtype is not None or self.model_max_length is not None:
|
||||
raise ValueError("These arguments cannot be specified.")
|
||||
self.compute_dtype = None
|
||||
self.model_max_length = None
|
||||
|
||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||
|
||||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
||||
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
||||
@@ -82,6 +84,5 @@ class ModelArguments:
|
||||
if self.quantization_bit is not None:
|
||||
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
if self.use_auth_token == True and self.hf_auth_token is not None:
|
||||
from huggingface_hub.hf_api import HfFolder # lazy load
|
||||
HfFolder.save_token(self.hf_auth_token)
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from llmtuner.tuner.core.parser import get_train_args, get_infer_args
|
||||
from llmtuner.tuner.core.loader import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core.utils import generate_model_card
|
||||
|
||||
@@ -2,13 +2,14 @@ import os
|
||||
import torch
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers.utils import cached_file
|
||||
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
from peft import (
|
||||
PeftModel,
|
||||
TaskType,
|
||||
LoraConfig,
|
||||
get_peft_model
|
||||
)
|
||||
from peft.utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.utils import find_all_linear_modules
|
||||
@@ -25,8 +26,7 @@ def init_adapter(
|
||||
model: "PreTrainedModel",
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: bool,
|
||||
is_mergeable: bool
|
||||
is_trainable: bool
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Initializes the adapters.
|
||||
@@ -36,34 +36,36 @@ def init_adapter(
|
||||
Note that the trainable parameters must be cast to float32.
|
||||
"""
|
||||
|
||||
if finetuning_args.finetuning_type == "none" and is_trainable:
|
||||
raise ValueError("You cannot use finetuning_type=none while training.")
|
||||
if (not is_trainable) and model_args.checkpoint_dir is None:
|
||||
logger.info("Checkpoint is not found at evaluation, load the original model.")
|
||||
return model
|
||||
|
||||
if finetuning_args.finetuning_type == "full" and is_trainable:
|
||||
logger.info("Fine-tuning method: Full")
|
||||
model = model.float()
|
||||
|
||||
if finetuning_args.finetuning_type == "freeze":
|
||||
if finetuning_args.finetuning_type == "freeze" and is_trainable:
|
||||
logger.info("Fine-tuning method: Freeze")
|
||||
num_layers = getattr(model.config, "num_layers")
|
||||
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
|
||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
|
||||
|
||||
trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
|
||||
for name, param in model.named_parameters():
|
||||
if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
|
||||
if not any(trainable_layer in name for trainable_layer in trainable_layers):
|
||||
param.requires_grad_(False)
|
||||
else:
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: LoRA")
|
||||
latest_checkpoint = None
|
||||
checkpoint_to_resume = None
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
|
||||
"Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
|
||||
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
|
||||
"The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
|
||||
|
||||
if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
|
||||
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||
if is_trainable and finetuning_args.resume_lora_training:
|
||||
checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||
else:
|
||||
checkpoints_to_merge = model_args.checkpoint_dir
|
||||
|
||||
@@ -74,10 +76,10 @@ def init_adapter(
|
||||
if len(checkpoints_to_merge) > 0:
|
||||
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
|
||||
|
||||
if latest_checkpoint is not None: # resume lora training or quantized inference
|
||||
model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
|
||||
if checkpoint_to_resume is not None: # resume lora training
|
||||
model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
|
||||
|
||||
if is_trainable and latest_checkpoint is None: # create new lora weights while training
|
||||
if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
|
||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||
target_modules = find_all_linear_modules(model, model_args.quantization_bit)
|
||||
else:
|
||||
@@ -89,7 +91,8 @@ def init_adapter(
|
||||
r=finetuning_args.lora_rank,
|
||||
lora_alpha=finetuning_args.lora_alpha,
|
||||
lora_dropout=finetuning_args.lora_dropout,
|
||||
target_modules=target_modules
|
||||
target_modules=target_modules,
|
||||
modules_to_save=finetuning_args.additional_target
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
@@ -97,3 +100,30 @@ def init_adapter(
|
||||
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_valuehead_params(
|
||||
model: "PreTrainedModel",
|
||||
model_args: "ModelArguments"
|
||||
) -> bool:
|
||||
kwargs = {
|
||||
"path_or_repo_id": model_args.reward_model,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"token": model_args.hf_hub_token,
|
||||
"revision": model_args.model_revision
|
||||
}
|
||||
try:
|
||||
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
|
||||
except:
|
||||
try:
|
||||
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
|
||||
except:
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
|
||||
return False
|
||||
|
||||
vhead_params = torch.load(vhead_file, map_location="cpu")
|
||||
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
|
||||
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
|
||||
return True
|
||||
|
||||
@@ -4,7 +4,6 @@ import torch
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Tuple
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
@@ -14,20 +13,21 @@ from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase
|
||||
)
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.models.llama import modeling_llama as LlamaModule
|
||||
from transformers.utils.versions import require_version
|
||||
from peft import PeftModel
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
try:
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
except ImportError:
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from llmtuner.extras.logging import reset_logging, get_logger
|
||||
from llmtuner.extras.misc import count_parameters
|
||||
from llmtuner.extras.save_and_load import load_valuehead_params
|
||||
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
|
||||
from llmtuner.extras.patches import llama_patch as LlamaPatches
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.tuner.core.adapter import init_adapter
|
||||
from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params
|
||||
from llmtuner.tuner.core.utils import prepare_model_for_training
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -38,11 +38,11 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
check_min_version("4.30.0")
|
||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
|
||||
require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft==0.4.0", "To fix: pip install peft==0.4.0")
|
||||
require_version("trl>=0.7.1", "To fix: pip install trl>=0.7.1")
|
||||
require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
|
||||
require_version("trl>=0.7.4", "To fix: pip install trl>=0.7.4")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
@@ -56,28 +56,22 @@ def load_model_and_tokenizer(
|
||||
|
||||
Support both training and inference.
|
||||
"""
|
||||
if (not is_trainable) and model_args.checkpoint_dir is None:
|
||||
logger.warning("Checkpoint is not found at evaluation, load the original model.")
|
||||
finetuning_args = FinetuningArguments(finetuning_type="none")
|
||||
|
||||
config_kwargs = {
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
"token": model_args.hf_hub_token
|
||||
}
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
split_special_tokens=model_args.split_special_tokens,
|
||||
padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
|
||||
**config_kwargs
|
||||
)
|
||||
|
||||
# Fix tokenizer (for ChatGLM2)
|
||||
if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
|
||||
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
|
||||
model_to_load = model_args.checkpoint_dir[0]
|
||||
else:
|
||||
@@ -85,29 +79,28 @@ def load_model_and_tokenizer(
|
||||
|
||||
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
|
||||
|
||||
# Fix tokenizer (for ChatGLM2 and ChatGLM3)
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
|
||||
# Set model dtype
|
||||
if model_args.compute_dtype is not None: # for training
|
||||
setattr(config, "torch_dtype", model_args.compute_dtype)
|
||||
else: # for evaluation, priority: bf16 > fp16 > fp32
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
|
||||
# Fix config (for Qwen)
|
||||
if is_trainable and hasattr(config, "fp16") and hasattr(config, "bf16"):
|
||||
if model_args.compute_dtype == torch.bfloat16:
|
||||
setattr(config, "bf16", True)
|
||||
else:
|
||||
setattr(config, "fp16", True)
|
||||
if getattr(config, "model_type", None) == "qwen":
|
||||
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
|
||||
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
|
||||
|
||||
# Set RoPE scaling
|
||||
if model_args.rope_scaling is not None:
|
||||
if hasattr(config, "use_dynamic_ntk"): # for Qwen models
|
||||
if is_trainable:
|
||||
logger.warning("Qwen model does not support RoPE scaling in training.")
|
||||
else:
|
||||
setattr(config, "use_dynamic_ntk", True)
|
||||
setattr(config, "use_logn_attn", True)
|
||||
logger.info("Using dynamic NTK scaling.")
|
||||
|
||||
elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
|
||||
require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
|
||||
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
else:
|
||||
if is_trainable:
|
||||
if model_args.rope_scaling == "dynamic":
|
||||
assert not model_args.flash_attn, "Flash attention does not support dynamic rope scaling."
|
||||
logger.warning(
|
||||
"Dynamic NTK may not work well with fine-tuning. "
|
||||
"See: https://github.com/huggingface/transformers/pull/24653"
|
||||
@@ -127,23 +120,32 @@ def load_model_and_tokenizer(
|
||||
model_args.rope_scaling, scaling_factor
|
||||
))
|
||||
|
||||
# Set FlashAttention-2
|
||||
if model_args.flash_attn:
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
if LlamaPatches.is_flash_attn_2_available:
|
||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
|
||||
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
else:
|
||||
logger.warning("FlashAttention-2 is not installed.")
|
||||
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
|
||||
logger.info("Current model automatically enables FlashAttention if installed.")
|
||||
else:
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
logger.warning("Current model does not support FlashAttention-2.")
|
||||
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
|
||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
|
||||
logger.warning("Using `--flash_attn` for faster training in large context length.")
|
||||
|
||||
# Set flash attention
|
||||
if model_args.flash_attn and getattr(config, "model_type", None) == "llama":
|
||||
import transformers.models.llama.modeling_llama as LlamaModule
|
||||
from llmtuner.extras.models.flash_llama import LlamaRMSNorm, LlamaAttention, _prepare_decoder_attention_mask
|
||||
LlamaModule.LlamaRMSNorm = LlamaRMSNorm
|
||||
LlamaModule.LlamaAttention = LlamaAttention
|
||||
LlamaModule.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
|
||||
if not hasattr(config, "num_key_value_heads"):
|
||||
setattr(config, "num_key_value_heads", getattr(config, "num_attention_heads"))
|
||||
if getattr(config, "pretraining_tp", 1) != 1:
|
||||
setattr(config, "pretraining_tp", 1)
|
||||
# Set shift short attention (S^2-Attn)
|
||||
if is_trainable and model_args.shift_attn:
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
|
||||
# Quantization configurations (using bitsandbytes library).
|
||||
is_mergeable = True
|
||||
if model_args.quantization_bit is not None:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
@@ -153,7 +155,7 @@ def load_model_and_tokenizer(
|
||||
config_kwargs["load_in_8bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
elif model_args.quantization_bit == 4:
|
||||
if model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
config_kwargs["load_in_4bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
@@ -163,7 +165,6 @@ def load_model_and_tokenizer(
|
||||
bnb_4bit_quant_type=model_args.quantization_type
|
||||
)
|
||||
|
||||
is_mergeable = False
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
@@ -176,13 +177,14 @@ def load_model_and_tokenizer(
|
||||
**config_kwargs
|
||||
)
|
||||
|
||||
# Disable custom generate method (for Qwen)
|
||||
if "GenerationMixin" not in str(model.generate.__func__):
|
||||
# Disable custom generate method (for Qwen and Baichuan2)
|
||||
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
# Fix LM head (for ChatGLM2)
|
||||
if not hasattr(model, "lm_head") and hasattr(model, "transformer"):
|
||||
# Fix LM head (for ChatGLM2 and ChatGLM3)
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
setattr(model, "lm_head", model.transformer.output_layer)
|
||||
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
||||
|
||||
# Register auto class to save the custom code files.
|
||||
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
|
||||
@@ -193,18 +195,17 @@ def load_model_and_tokenizer(
|
||||
tokenizer.__class__.register_for_auto_class()
|
||||
|
||||
# Initialize adapters
|
||||
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
|
||||
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||
model = model.train() if is_trainable else model.eval()
|
||||
|
||||
# Prepare model with valuehead for RLHF
|
||||
if stage == "rm" or stage == "ppo":
|
||||
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
model._keys_to_ignore_on_save = None
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
reset_logging()
|
||||
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
|
||||
if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
|
||||
if load_valuehead_params(model, model_args):
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "reward_head_weight"),
|
||||
"summary.bias": getattr(model, "reward_head_bias")
|
||||
@@ -212,18 +213,24 @@ def load_model_and_tokenizer(
|
||||
|
||||
if stage == "ppo": # load reward model
|
||||
logger.info("Load reward model from {}".format(model_args.reward_model))
|
||||
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
|
||||
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
|
||||
if isinstance(model.pretrained_model, PeftModel):
|
||||
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
|
||||
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
|
||||
if "default" in name:
|
||||
param.data = param.data.to(torch.float32) # trainable params should in fp32
|
||||
assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
|
||||
|
||||
# Prepare model for inference
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
|
||||
model = model.to(infer_dtype) if model_args.quantization_bit is None else model
|
||||
model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
|
||||
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param
|
||||
))
|
||||
|
||||
if not is_trainable:
|
||||
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
@@ -1,37 +1,24 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import datasets
|
||||
import transformers
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.utils.versions import require_version
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import parse_args
|
||||
from llmtuner.hparams import (
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
GeneratingArguments
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
return parser.parse_args_into_dataclasses()
|
||||
|
||||
|
||||
def parse_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[
|
||||
@@ -39,18 +26,16 @@ def parse_train_args(
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
GeneratingArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
GeneratingArguments
|
||||
))
|
||||
return _parse_args(parser, args)
|
||||
return parse_args(parser, args)
|
||||
|
||||
|
||||
def parse_infer_args(
|
||||
@@ -67,7 +52,7 @@ def parse_infer_args(
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
))
|
||||
return _parse_args(parser, args)
|
||||
return parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(
|
||||
@@ -77,10 +62,9 @@ def get_train_args(
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
GeneratingArguments
|
||||
]:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
@@ -93,47 +77,43 @@ def get_train_args(
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
|
||||
data_args.init_for_training()
|
||||
# Check arguments
|
||||
data_args.init_for_training(training_args.seed)
|
||||
|
||||
if general_args.stage != "pt" and data_args.template is None:
|
||||
if finetuning_args.stage != "pt" and data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if general_args.stage != "sft" and training_args.predict_with_generate:
|
||||
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
||||
|
||||
if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
||||
|
||||
if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
|
||||
if finetuning_args.stage in ["rm", "ppo"]:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
|
||||
if training_args.load_best_model_at_end:
|
||||
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
|
||||
|
||||
if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
|
||||
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
|
||||
if finetuning_args.stage == "ppo" and not training_args.do_train:
|
||||
raise ValueError("PPO training does not support evaluation.")
|
||||
|
||||
if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
|
||||
raise ValueError("PPO and DPO stages can only be performed at training.")
|
||||
|
||||
if general_args.stage in ["rm", "dpo"]:
|
||||
if finetuning_args.stage in ["rm", "dpo"]:
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
if not dataset_attr.ranking:
|
||||
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
|
||||
|
||||
if general_args.stage == "ppo" and model_args.reward_model is None:
|
||||
if finetuning_args.stage == "ppo" and model_args.reward_model is None:
|
||||
raise ValueError("Reward model is necessary for PPO training.")
|
||||
|
||||
if general_args.stage == "ppo" and training_args.deepspeed is not None:
|
||||
raise ValueError("PPO training is incompatible with DeepSpeed, use Accelerate instead.")
|
||||
|
||||
if general_args.stage == "ppo" and data_args.streaming:
|
||||
raise ValueError("Streaming mode does not suppport PPO training currently.")
|
||||
if finetuning_args.stage == "ppo" and model_args.shift_attn:
|
||||
raise ValueError("PPO training is incompatible with S^2-Attn.")
|
||||
|
||||
if training_args.max_steps == -1 and data_args.streaming:
|
||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||
|
||||
if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
|
||||
raise ValueError("Streaming mode should have an integer val size.")
|
||||
|
||||
if training_args.do_train and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||
|
||||
@@ -143,23 +123,21 @@ def get_train_args(
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
if len(model_args.checkpoint_dir) != 1:
|
||||
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
|
||||
elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
|
||||
raise ValueError("Quantized model only accepts a single checkpoint.")
|
||||
if (
|
||||
model_args.checkpoint_dir is not None
|
||||
and len(model_args.checkpoint_dir) != 1
|
||||
and finetuning_args.finetuning_type != "lora"
|
||||
):
|
||||
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
|
||||
|
||||
if model_args.quantization_bit is not None and (not training_args.do_train):
|
||||
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
|
||||
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
|
||||
|
||||
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
|
||||
logger.warning("We recommend enable mixed precision training.")
|
||||
|
||||
# postprocess data_args
|
||||
if data_args.max_samples is not None and data_args.streaming:
|
||||
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
|
||||
data_args.max_samples = None
|
||||
if (not training_args.do_train) and model_args.quantization_bit is not None:
|
||||
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
|
||||
# postprocess training_args
|
||||
if (
|
||||
@@ -178,10 +156,9 @@ def get_train_args(
|
||||
and os.path.isdir(training_args.output_dir)
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
|
||||
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
|
||||
|
||||
if last_checkpoint is not None:
|
||||
training_args_dict = training_args.to_dict()
|
||||
@@ -192,14 +169,10 @@ def get_train_args(
|
||||
)
|
||||
|
||||
# postprocess model_args
|
||||
if training_args.bf16:
|
||||
if not torch.cuda.is_bf16_supported():
|
||||
raise ValueError("Current device does not support bf16 training.")
|
||||
model_args.compute_dtype = torch.bfloat16
|
||||
else:
|
||||
model_args.compute_dtype = torch.float16
|
||||
|
||||
model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
|
||||
model_args.compute_dtype = (
|
||||
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
|
||||
)
|
||||
model_args.model_max_length = data_args.cutoff_len
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
|
||||
@@ -211,7 +184,7 @@ def get_train_args(
|
||||
# Set seed before initializing model.
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args, general_args
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(
|
||||
@@ -230,11 +203,11 @@ def get_infer_args(
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
if len(model_args.checkpoint_dir) != 1:
|
||||
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
|
||||
elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
|
||||
raise ValueError("Quantized model only accepts a single checkpoint.")
|
||||
if (
|
||||
model_args.checkpoint_dir is not None
|
||||
and len(model_args.checkpoint_dir) != 1
|
||||
and finetuning_args.finetuning_type != "lora"
|
||||
):
|
||||
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def find_all_linear_modules(
|
||||
model: "PreTrainedModel",
|
||||
quantization_bit: Optional[int] = None,
|
||||
output_layer_name: Optional[str] = "lm_head"
|
||||
quantization_bit: Optional[int] = None
|
||||
) -> List[str]:
|
||||
if quantization_bit is not None:
|
||||
import bitsandbytes as bnb
|
||||
@@ -18,23 +22,41 @@ def find_all_linear_modules(
|
||||
else:
|
||||
linear_cls = torch.nn.Linear
|
||||
|
||||
output_layer_names = ["lm_head"]
|
||||
if model.config.model_type == "chatglm":
|
||||
output_layer_names.append("output_layer")
|
||||
|
||||
module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if output_layer_name not in name and isinstance(module, linear_cls):
|
||||
if (
|
||||
isinstance(module, linear_cls)
|
||||
and not any([output_layer in name for output_layer in output_layer_names])
|
||||
):
|
||||
module_names.add(name.split(".")[-1])
|
||||
|
||||
if output_layer_name in module_names:
|
||||
module_names.pop(output_layer_name)
|
||||
|
||||
logger.info("Found linear modules: {}".format(",".join(module_names)))
|
||||
return list(module_names)
|
||||
|
||||
|
||||
def generate_model_card(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
finetuning_args: "FinetuningArguments"
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"tasks": "text-generation",
|
||||
"finetuned_from": model_args.model_name_or_path,
|
||||
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
|
||||
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
|
||||
}
|
||||
|
||||
|
||||
def prepare_model_for_training(
|
||||
model: "PreTrainedModel",
|
||||
finetuning_type: str,
|
||||
finetuning_args: "FinetuningArguments",
|
||||
output_layer_name: Optional[str] = "lm_head",
|
||||
use_gradient_checkpointing: Optional[bool] = True,
|
||||
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
|
||||
layernorm_names: Optional[Set[str]] = LAYERNORM_NAMES
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Includes:
|
||||
@@ -43,30 +65,43 @@ def prepare_model_for_training(
|
||||
(3) upcast the lm_head to fp32
|
||||
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
|
||||
"""
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
param.data = param.data.to(torch.float32)
|
||||
if finetuning_args.upcast_layernorm:
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
|
||||
param.data = param.data.to(torch.float32)
|
||||
logger.info("Upcasting weights in layernorm in float32.")
|
||||
|
||||
if finetuning_args.neft_alpha > 1e-6:
|
||||
def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||
if module.training:
|
||||
dims = torch.tensor(output.size(1) * output.size(2))
|
||||
mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims)
|
||||
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
|
||||
return output
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
|
||||
logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
|
||||
|
||||
if use_gradient_checkpointing:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
def make_inputs_require_grad(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||
output.requires_grad_(True)
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||
logger.info("Gradient checkpointing enabled.")
|
||||
|
||||
if finetuning_type != "full" and hasattr(model, output_layer_name):
|
||||
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
|
||||
input_dtype = output_layer.weight.dtype
|
||||
|
||||
class CastOutputToFloat(torch.nn.Sequential):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return super().forward(x.to(input_dtype)).to(torch.float32)
|
||||
|
||||
setattr(model, output_layer_name, CastOutputToFloat(output_layer))
|
||||
if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
|
||||
output_layer = getattr(model, output_layer_name)
|
||||
if isinstance(output_layer, torch.nn.Linear):
|
||||
def fp32_forward_pre_hook(module: torch.nn.Module, args: Tuple[torch.Tensor]):
|
||||
return args[0].to(output_layer.weight.dtype)
|
||||
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||
return output.to(torch.float32)
|
||||
output_layer.register_forward_pre_hook(fp32_forward_pre_hook)
|
||||
output_layer.register_forward_hook(fp32_forward_post_hook)
|
||||
|
||||
return model
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||
from transformers import BatchEncoding, Trainer
|
||||
from trl import DPOTrainer
|
||||
from trl.trainer.utils import disable_dropout_in_model
|
||||
@@ -19,6 +19,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
disable_dropout: Optional[bool] = True,
|
||||
loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
|
||||
**kwargs
|
||||
):
|
||||
if disable_dropout:
|
||||
@@ -29,9 +30,11 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
self.is_encoder_decoder = model.config.is_encoder_decoder
|
||||
self.ref_model = ref_model
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.generate_during_eval = False # disable at evaluation
|
||||
self.label_pad_token_id = IGNORE_INDEX
|
||||
self.padding_value = 0
|
||||
self.beta = beta
|
||||
self.loss_type = loss_type
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
@@ -40,8 +43,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
|
||||
if ref_model is not None:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.ref_model, = self.accelerator._prepare_deepspeed(self.ref_model)
|
||||
self.ref_model.eval()
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
|
||||
@@ -1,20 +1,24 @@
|
||||
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
|
||||
|
||||
from copy import deepcopy
|
||||
from peft import PeftModel
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.hparams import ModelArguments
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
|
||||
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.hparams import DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_dpo(
|
||||
@@ -29,9 +33,30 @@ def run_dpo(
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
data_collator = DPODataCollatorWithPadding(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=4,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
# Create reference model
|
||||
if finetuning_args.dpo_ref_model is not None:
|
||||
ref_model_args_dict = model_args.to_dict()
|
||||
ref_model_args_dict.update(dict(
|
||||
model_name_or_path=finetuning_args.dpo_ref_model,
|
||||
checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint
|
||||
))
|
||||
ref_model_args = ModelArguments(**ref_model_args_dict)
|
||||
ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft")
|
||||
logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model))
|
||||
elif training_args.do_train:
|
||||
if isinstance(model, PeftModel):
|
||||
ref_model = None
|
||||
else:
|
||||
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
|
||||
logger.info("Created reference model from the model itself.")
|
||||
else:
|
||||
ref_model = model
|
||||
|
||||
# Update arguments
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
@@ -40,7 +65,7 @@ def run_dpo(
|
||||
trainer = CustomDPOTrainer(
|
||||
beta=finetuning_args.dpo_beta,
|
||||
model=model,
|
||||
ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
@@ -51,9 +76,27 @@ def run_dpo(
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
trainer.save_model()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
if id(model) == id(ref_model): # unable to compute rewards without a reference model
|
||||
logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.")
|
||||
remove_keys = [key for key in metrics.keys() if "rewards" in key]
|
||||
for key in remove_keys:
|
||||
metrics.pop(key)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Create model card
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
else:
|
||||
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
|
||||
@@ -1,23 +1,25 @@
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from trl import PPOTrainer
|
||||
from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
|
||||
from trl.core import PPODecorators, logprobs_from_logits
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
|
||||
from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from llmtuner.hparams import GeneratingArguments
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -30,34 +32,47 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_args: "ModelArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: List["TrainerCallback"],
|
||||
compute_dtype: torch.dtype,
|
||||
**kwargs
|
||||
):
|
||||
PPOTrainer.__init__(self, **kwargs)
|
||||
if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
|
||||
raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
|
||||
|
||||
self.args = training_args
|
||||
self.generating_args = generating_args
|
||||
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
|
||||
self.compute_dtype = compute_dtype
|
||||
self.model_args = model_args
|
||||
self.finetuning_args = finetuning_args
|
||||
self.generation_config = GenerationConfig(
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
**generating_args.to_dict()
|
||||
)
|
||||
self.state = TrainerState()
|
||||
self.control = TrainerControl()
|
||||
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
|
||||
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
|
||||
if self.args.max_steps > 0:
|
||||
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
||||
|
||||
def ppo_train(self, max_target_length: int) -> None:
|
||||
def ppo_train(self) -> None:
|
||||
r"""
|
||||
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
|
||||
"""
|
||||
total_train_batch_size = (
|
||||
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
|
||||
)
|
||||
len_dataloader = len(self.dataloader)
|
||||
num_examples = len(self.dataset)
|
||||
num_train_epochs = self.args.num_train_epochs
|
||||
max_steps = math.ceil(num_train_epochs * len_dataloader)
|
||||
if self.args.max_steps > 0:
|
||||
num_examples = total_train_batch_size * self.args.max_steps
|
||||
num_train_epochs = sys.maxsize
|
||||
max_steps = self.args.max_steps
|
||||
steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps
|
||||
else:
|
||||
len_dataloader = len(self.dataloader)
|
||||
num_examples = len(self.dataset)
|
||||
num_train_epochs = self.args.num_train_epochs
|
||||
max_steps = math.ceil(num_train_epochs * len_dataloader)
|
||||
steps_in_epoch = len_dataloader
|
||||
|
||||
self.state.max_steps = max_steps
|
||||
self.state.num_train_epochs = num_train_epochs
|
||||
@@ -74,25 +89,18 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
logger.info(f" Total optimization steps = {max_steps}")
|
||||
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
generating_args = self.generating_args.to_dict()
|
||||
generating_args.update(dict(
|
||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
pad_token_id=self.tokenizer.pad_token_id
|
||||
))
|
||||
|
||||
length_sampler = LengthSampler(max_target_length // 2, max_target_length)
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
|
||||
dataiter = iter(self.dataloader)
|
||||
steps_trained = 0
|
||||
loss_meter = AverageMeter()
|
||||
reward_meter = AverageMeter()
|
||||
self.log_callback.on_train_begin(self.args, self.state, self.control)
|
||||
|
||||
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
|
||||
batch = next(dataiter)
|
||||
steps_trained += 1
|
||||
try:
|
||||
batch = next(dataiter)
|
||||
except StopIteration:
|
||||
dataiter = iter(self.dataloader)
|
||||
batch = next(dataiter)
|
||||
|
||||
# Cast to inference mode
|
||||
unwrapped_model.gradient_checkpointing_disable()
|
||||
@@ -100,9 +108,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self.model.eval()
|
||||
|
||||
# Get inputs
|
||||
queries, responses = self.get_inputs(batch, length_sampler, generating_args)
|
||||
self.tokenizer.padding_side = "right" # change padding side
|
||||
rewards = self.get_rewards(queries, responses, unwrapped_model)
|
||||
queries, responses, rewards = [], [], []
|
||||
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
|
||||
mini_batch_queries, mini_batch_responses = self.get_inputs(batch[idx:idx+self.config.mini_batch_size])
|
||||
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses, unwrapped_model)
|
||||
queries.extend(mini_batch_queries)
|
||||
responses.extend(mini_batch_responses)
|
||||
rewards.extend(mini_batch_rewards)
|
||||
|
||||
# Cast to training mode
|
||||
unwrapped_model.gradient_checkpointing_enable()
|
||||
@@ -112,9 +125,17 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
# Run PPO step
|
||||
stats = self.step(queries, responses, rewards)
|
||||
self.tokenizer.padding_side = "left" # restore padding side
|
||||
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
|
||||
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
|
||||
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
|
||||
|
||||
if self.config.log_with is not None:
|
||||
try:
|
||||
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
|
||||
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
|
||||
self.log_stats(stats, batch, rewards)
|
||||
except:
|
||||
logger.warning("Failed to save stats due to unknown errors.")
|
||||
|
||||
self.state.global_step += 1
|
||||
self.log_callback.on_step_end(self.args, self.state, self.control)
|
||||
|
||||
@@ -123,7 +144,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
loss=round(loss_meter.avg, 4),
|
||||
reward=round(reward_meter.avg, 4),
|
||||
learning_rate=stats["ppo/learning_rate"],
|
||||
epoch=round(step / len_dataloader, 2)
|
||||
epoch=round(step / steps_in_epoch, 2)
|
||||
)
|
||||
tqdm.write(str(logs))
|
||||
logs["step"] = step
|
||||
@@ -143,49 +164,39 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||
break
|
||||
|
||||
if steps_trained == len_dataloader:
|
||||
dataiter = iter(self.dataloader)
|
||||
steps_trained = 0
|
||||
|
||||
self.log_callback.on_train_end(
|
||||
self.log_callback.on_train_end(self.args, self.state, self.control)
|
||||
self.save_callback.on_train_end(
|
||||
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_inputs(
|
||||
self,
|
||||
batch: Dict[str, torch.Tensor],
|
||||
length_sampler: Callable,
|
||||
generating_args: Dict[str, Any]
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
r"""
|
||||
Generates model's responses given queries.
|
||||
"""
|
||||
generating_args["max_new_tokens"] = length_sampler()
|
||||
gen_kwargs = dict(
|
||||
generation_config=GenerationConfig(**generating_args),
|
||||
if self.finetuning_args.upcast_layernorm:
|
||||
layernorm_params = dump_layernorm(self.model)
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
response: torch.Tensor = unwrapped_model.generate(
|
||||
generation_config=self.generation_config,
|
||||
logits_processor=get_logits_processor(),
|
||||
**batch
|
||||
)
|
||||
|
||||
input_ids = batch["input_ids"]
|
||||
self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
|
||||
self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
|
||||
query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()
|
||||
if self.finetuning_args.upcast_layernorm:
|
||||
restore_layernorm(self.model, layernorm_params)
|
||||
|
||||
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
|
||||
queries, responses = [], []
|
||||
for i in range(len(query)):
|
||||
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
|
||||
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
|
||||
|
||||
if len(response_index) == 0:
|
||||
response_length = 1 # allow empty response
|
||||
elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
||||
response_length = response_index[-1] + 2 # save the EOS token
|
||||
else:
|
||||
response_length = response_index[-1] + 1
|
||||
response_length = response_index[-1].item() + 1
|
||||
|
||||
queries.append(query[i, query_length:]) # remove padding from left
|
||||
responses.append(response[i, :response_length]) # remove padding from right
|
||||
@@ -205,7 +216,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
batch = self.prepare_model_inputs(queries, responses)
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
|
||||
|
||||
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
|
||||
@@ -213,13 +224,14 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
rewards = []
|
||||
for i in range(values.size(0)):
|
||||
end_index = batch["attention_mask"][i].nonzero()[-1] # use the score on the EOS token
|
||||
end_indexes = (batch["input_ids"][i] != self.tokenizer.pad_token_id).nonzero()
|
||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
|
||||
replace_model(unwrapped_model, target="default")
|
||||
return rewards
|
||||
|
||||
@PPODecorators.empty_cuda_cache()
|
||||
@PPODecorators.empty_device_cache()
|
||||
def batched_forward_pass(
|
||||
self,
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
@@ -250,7 +262,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
input_ids = input_kwargs["input_ids"]
|
||||
attention_mask = input_kwargs["attention_mask"]
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
logits, _, values = model(**input_kwargs)
|
||||
|
||||
if values.size(0) != input_ids.size(0): # adapt to chatglm2
|
||||
@@ -263,7 +275,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
for j in range(len(query_batch)):
|
||||
start = len(query_batch[j]) - 1
|
||||
if attention_mask[j, 0] == 0: # offset left padding
|
||||
start += attention_mask[j, :].nonzero()[0]
|
||||
start += attention_mask[j, :].nonzero()[0].item()
|
||||
end = start + len(response_batch[j])
|
||||
|
||||
if response_masks is not None:
|
||||
|
||||
@@ -1,40 +1,35 @@
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
|
||||
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
||||
if target == "reward": # save default head temporarily
|
||||
valuehead_state_dict = model.v_head.state_dict()
|
||||
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
|
||||
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
|
||||
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
|
||||
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
|
||||
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
|
||||
|
||||
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "{}_head_weight".format(target)),
|
||||
"summary.bias": getattr(model, "{}_head_bias".format(target))
|
||||
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
|
||||
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
|
||||
})
|
||||
|
||||
|
||||
def cast_layernorm_dtype(
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
compute_dtype: torch.dtype,
|
||||
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None,
|
||||
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
|
||||
) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
|
||||
|
||||
layer_norm_state_dict = {}
|
||||
|
||||
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
|
||||
layer_norm_params = {}
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
if layer_norm_params is None:
|
||||
layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
|
||||
param.data = param.data.to(compute_dtype)
|
||||
else:
|
||||
param.data = layer_norm_params[name] # restore float32 weights
|
||||
if param.data.dtype == torch.float32:
|
||||
layer_norm_params[name] = param.data.detach().clone()
|
||||
param.data = param.data.to(model.config.torch_dtype)
|
||||
|
||||
return model, layer_norm_state_dict
|
||||
return layer_norm_params
|
||||
|
||||
|
||||
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
|
||||
for name, param in model.named_parameters():
|
||||
if name in layernorm_params:
|
||||
param.data = layernorm_params[name]
|
||||
|
||||
@@ -42,18 +42,23 @@ def run_ppo(
|
||||
ppo_epochs=1,
|
||||
max_grad_norm=training_args.max_grad_norm,
|
||||
seed=training_args.seed,
|
||||
optimize_cuda_cache=True
|
||||
optimize_device_cache=True,
|
||||
target=finetuning_args.ppo_target,
|
||||
log_with=finetuning_args.ppo_logger,
|
||||
use_score_scaling=finetuning_args.ppo_score_norm,
|
||||
use_score_norm=finetuning_args.ppo_score_norm,
|
||||
accelerator_kwargs={"step_scheduler_with_optimizer": False}
|
||||
)
|
||||
|
||||
if finetuning_args.ppo_score_norm:
|
||||
ppo_config.use_score_scaling = True
|
||||
ppo_config.use_score_norm = True
|
||||
|
||||
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
|
||||
total_train_batch_size = (
|
||||
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
|
||||
)
|
||||
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
|
||||
if training_args.max_steps > 0:
|
||||
num_training_steps = training_args.max_steps
|
||||
else:
|
||||
total_train_batch_size = (
|
||||
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
|
||||
)
|
||||
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
training_args.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
@@ -63,10 +68,11 @@ def run_ppo(
|
||||
|
||||
# Initialize our Trainer
|
||||
ppo_trainer = CustomPPOTrainer(
|
||||
model_args=model_args,
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
generating_args=generating_args,
|
||||
callbacks=callbacks + [SavePeftModelCallback()],
|
||||
compute_dtype=model_args.compute_dtype,
|
||||
config=ppo_config,
|
||||
model=model,
|
||||
ref_model=None,
|
||||
@@ -79,7 +85,7 @@ def run_ppo(
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
|
||||
ppo_trainer.ppo_train()
|
||||
ppo_trainer.save_model()
|
||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
@@ -6,7 +6,7 @@ from transformers import DataCollatorForLanguageModeling, Trainer
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
@@ -38,10 +38,10 @@ def run_pt(
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
trainer.save_model()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
@@ -56,3 +56,10 @@ def run_pt(
|
||||
metrics["perplexity"] = perplexity
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Create model card
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
else:
|
||||
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
|
||||
@@ -34,7 +34,7 @@ class PairwiseTrainer(Trainer):
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
|
||||
Note that the first element will be removed from the output tuple.
|
||||
Note that the first element will be removed from the output tuple.
|
||||
See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
|
||||
"""
|
||||
# Compute rewards
|
||||
@@ -45,9 +45,6 @@ class PairwiseTrainer(Trainer):
|
||||
# Split the inputs and rewards into two parts, chosen and rejected
|
||||
batch_size = inputs["input_ids"].size(0) // 2
|
||||
chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
|
||||
chosen_attn_mask, rejected_attn_mask = (
|
||||
inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:]
|
||||
)
|
||||
chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
|
||||
chosen_scores, rejected_scores = [], []
|
||||
|
||||
@@ -55,8 +52,8 @@ class PairwiseTrainer(Trainer):
|
||||
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
|
||||
loss = 0
|
||||
for i in range(batch_size):
|
||||
chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1
|
||||
rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1
|
||||
chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
||||
rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
||||
check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
|
||||
|
||||
if len(check_divergence) == 0:
|
||||
@@ -69,7 +66,7 @@ class PairwiseTrainer(Trainer):
|
||||
assert div_index > 0
|
||||
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
|
||||
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
|
||||
if return_outputs: # use the score on the EOS token for inference
|
||||
if return_outputs: # use the score on the last token except pad token for inference
|
||||
chosen_scores.append(chosen_rewards[i, chosen_length-1])
|
||||
rejected_scores.append(rejected_rewards[i, rejected_length-1])
|
||||
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
|
||||
@@ -95,7 +92,6 @@ class PairwiseTrainer(Trainer):
|
||||
|
||||
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
|
||||
logger.info(f"Saving prediction results to {output_prediction_file}")
|
||||
|
||||
chosen_scores, rejected_scores = predict_results.predictions
|
||||
|
||||
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# Inspired by:
|
||||
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
@@ -7,7 +6,7 @@ from transformers import Seq2SeqTrainingArguments
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import SavePeftModelCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.tuner.rm.metric import compute_accuracy
|
||||
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
|
||||
from llmtuner.tuner.rm.trainer import PairwiseTrainer
|
||||
@@ -27,8 +26,9 @@ def run_rm(
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer)
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
|
||||
|
||||
# Update arguments
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
@@ -47,10 +47,10 @@ def run_rm(
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train()
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
trainer.save_model()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
@@ -66,3 +66,10 @@ def run_rm(
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(predict_results)
|
||||
|
||||
# Create model card
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
else:
|
||||
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
|
||||
@@ -33,28 +33,20 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
|
||||
if self.args.predict_with_generate:
|
||||
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
|
||||
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
|
||||
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
|
||||
if prompt_len > label_len:
|
||||
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
|
||||
if label_len > prompt_len:
|
||||
inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
|
||||
if "attention_mask" in inputs:
|
||||
inputs["attention_mask"] = self._pad_tensors_to_target_len(
|
||||
inputs["attention_mask"], inputs["labels"], pad_token_id=0
|
||||
)
|
||||
if "position_ids" in inputs:
|
||||
inputs["position_ids"] = self._pad_tensors_to_target_len(
|
||||
inputs["position_ids"], inputs["labels"], pad_token_id=0
|
||||
)
|
||||
inputs["labels"] = inputs["labels"][:, :prompt_len] # truncate the labels instead of padding the inputs
|
||||
|
||||
loss, generated_tokens, labels = super().prediction_step(
|
||||
loss, generated_tokens, _ = super().prediction_step(
|
||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
||||
)
|
||||
if generated_tokens is not None and self.args.predict_with_generate:
|
||||
generated_tokens[:, :max(prompt_len, label_len)] = self.tokenizer.pad_token_id
|
||||
generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
|
||||
generated_tokens = generated_tokens.contiguous()
|
||||
|
||||
return loss, generated_tokens, labels
|
||||
@@ -62,14 +54,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
def _pad_tensors_to_target_len(
|
||||
self,
|
||||
src_tensor: torch.Tensor,
|
||||
tgt_tensor: torch.Tensor,
|
||||
pad_token_id: Optional[int] = None
|
||||
tgt_tensor: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Pads the tensor to the same length as the target tensor.
|
||||
"""
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.tokenizer.pad_token_id
|
||||
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
|
||||
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
|
||||
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
|
||||
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
|
||||
return padded_tensor.contiguous() # in contiguous memory
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
|
||||
@@ -7,7 +7,7 @@ from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.tuner.sft.metric import ComputeMetrics
|
||||
from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
|
||||
|
||||
@@ -33,13 +33,14 @@ def run_sft(
|
||||
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shift short attention
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
# Override the decoding parameters of Seq2SeqTrainer
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(
|
||||
generation_max_length=training_args.generation_max_length or data_args.max_target_length,
|
||||
generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
|
||||
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
|
||||
))
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
@@ -64,10 +65,10 @@ def run_sft(
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
trainer.save_model()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
@@ -87,3 +88,10 @@ def run_sft(
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(predict_results)
|
||||
|
||||
# Create model card
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
else:
|
||||
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
|
||||
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
|
||||
from llmtuner.tuner.pt import run_pt
|
||||
from llmtuner.tuner.sft import run_sft
|
||||
from llmtuner.tuner.rm import run_rm
|
||||
@@ -17,29 +17,32 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
|
||||
model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
||||
callbacks = [LogCallback()] if callbacks is None else callbacks
|
||||
|
||||
if general_args.stage == "pt":
|
||||
if finetuning_args.stage == "pt":
|
||||
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif general_args.stage == "sft":
|
||||
elif finetuning_args.stage == "sft":
|
||||
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif general_args.stage == "rm":
|
||||
elif finetuning_args.stage == "rm":
|
||||
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif general_args.stage == "ppo":
|
||||
elif finetuning_args.stage == "ppo":
|
||||
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif general_args.stage == "dpo":
|
||||
elif finetuning_args.stage == "dpo":
|
||||
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
else:
|
||||
raise ValueError("Unknown task.")
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
|
||||
model_args, _, training_args, finetuning_args, _, _ = get_train_args(args)
|
||||
model_args, _, finetuning_args, _ = get_infer_args(args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
|
||||
model.config.use_cache = True
|
||||
model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
|
||||
try:
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
tokenizer.padding_side = "left" # restore padding side
|
||||
tokenizer.init_kwargs["padding_side"] = "left"
|
||||
tokenizer.save_pretrained(model_args.export_dir)
|
||||
except:
|
||||
logger.warning("Cannot save tokenizer, please copy the files manually.")
|
||||
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from llmtuner.chat.stream_chat import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.hparams import GeneratingArguments
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
|
||||
|
||||
class WebChatModel(ChatModel):
|
||||
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None:
|
||||
if lazy_init:
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.generating_args = GeneratingArguments()
|
||||
else:
|
||||
super().__init__(args)
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
lang: str,
|
||||
model_name: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
system_prompt: str
|
||||
):
|
||||
if self.model is not None:
|
||||
yield ALERTS["err_exists"][lang]
|
||||
return
|
||||
|
||||
if not model_name:
|
||||
yield ALERTS["err_no_model"][lang]
|
||||
return
|
||||
|
||||
model_name_or_path = get_model_path(model_name)
|
||||
if not model_name_or_path:
|
||||
yield ALERTS["err_no_path"][lang]
|
||||
return
|
||||
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
|
||||
)
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
yield ALERTS["info_loading"][lang]
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None,
|
||||
template=template,
|
||||
system_prompt=system_prompt
|
||||
)
|
||||
super().__init__(args)
|
||||
|
||||
yield ALERTS["info_loaded"][lang]
|
||||
|
||||
def unload_model(self, lang: str):
|
||||
yield ALERTS["info_unloading"][lang]
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
torch_gc()
|
||||
yield ALERTS["info_unloaded"][lang]
|
||||
|
||||
def predict(
|
||||
self,
|
||||
chatbot: List[Tuple[str, str]],
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
system: str,
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float
|
||||
):
|
||||
chatbot.append([query, ""])
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
):
|
||||
response += new_text
|
||||
response = self.postprocess(response)
|
||||
new_history = history + [(query, response)]
|
||||
chatbot[-1] = [query, response]
|
||||
yield chatbot, new_history
|
||||
|
||||
def postprocess(self, response: str) -> str:
|
||||
blocks = response.split("```")
|
||||
for i, block in enumerate(blocks):
|
||||
if i % 2 == 0:
|
||||
blocks[i] = block.replace("<", "<").replace(">", ">")
|
||||
return "```".join(blocks)
|
||||
101
src/llmtuner/webui/chatter.py
Normal file
101
src/llmtuner/webui/chatter.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
from llmtuner.chat.stream_chat import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.hparams import GeneratingArguments
|
||||
from llmtuner.webui.common import get_save_dir
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.webui.manager import Manager
|
||||
|
||||
|
||||
class WebChatModel(ChatModel):
|
||||
|
||||
def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
|
||||
self.manager = manager
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.generating_args = GeneratingArguments()
|
||||
if not lazy_init:
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def loaded(self) -> bool:
|
||||
return self.model is not None
|
||||
|
||||
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
lang = get("top.lang")
|
||||
error = ""
|
||||
if self.loaded:
|
||||
error = ALERTS["err_exists"][lang]
|
||||
elif not get("top.model_name"):
|
||||
error = ALERTS["err_no_model"][lang]
|
||||
elif not get("top.model_path"):
|
||||
error = ALERTS["err_no_path"][lang]
|
||||
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error
|
||||
return
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
||||
])
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
yield ALERTS["info_loading"][lang]
|
||||
args = dict(
|
||||
model_name_or_path=get("top.model_path"),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
template=get("top.template"),
|
||||
system_prompt=get("top.system_prompt"),
|
||||
flash_attn=get("top.flash_attn"),
|
||||
shift_attn=get("top.shift_attn"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
|
||||
)
|
||||
super().__init__(args)
|
||||
|
||||
yield ALERTS["info_loaded"][lang]
|
||||
|
||||
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
||||
lang = data[self.manager.get_elem_by_name("top.lang")]
|
||||
yield ALERTS["info_unloading"][lang]
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
torch_gc()
|
||||
yield ALERTS["info_unloaded"][lang]
|
||||
|
||||
def predict(
|
||||
self,
|
||||
chatbot: List[Tuple[str, str]],
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
system: str,
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float
|
||||
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
|
||||
chatbot.append([query, ""])
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
):
|
||||
response += new_text
|
||||
new_history = history + [(query, response)]
|
||||
chatbot[-1] = [query, self.postprocess(response)]
|
||||
yield chatbot, new_history
|
||||
|
||||
def postprocess(self, response: str) -> str:
|
||||
blocks = response.split("```")
|
||||
for i, block in enumerate(blocks):
|
||||
if i % 2 == 0:
|
||||
blocks[i] = block.replace("<", "<").replace(">", ">")
|
||||
return "```".join(blocks)
|
||||
@@ -1,12 +1,17 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import json
|
||||
import gradio as gr
|
||||
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
|
||||
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from typing import Any, Dict, Optional
|
||||
from transformers.utils import (
|
||||
WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
ADAPTER_SAFE_WEIGHTS_NAME
|
||||
)
|
||||
|
||||
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
|
||||
|
||||
|
||||
DEFAULT_CACHE_DIR = "cache"
|
||||
@@ -14,6 +19,14 @@ DEFAULT_DATA_DIR = "data"
|
||||
DEFAULT_SAVE_DIR = "saves"
|
||||
USER_CONFIG = "user.config"
|
||||
DATA_CONFIG = "dataset_info.json"
|
||||
CKPT_NAMES = [
|
||||
WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
ADAPTER_SAFE_WEIGHTS_NAME
|
||||
]
|
||||
|
||||
|
||||
def get_save_dir(*args) -> os.PathLike:
|
||||
@@ -32,7 +45,7 @@ def load_config() -> Dict[str, Any]:
|
||||
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
|
||||
|
||||
|
||||
def save_config(lang: str, model_name: str, model_path: str) -> None:
|
||||
def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
|
||||
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
|
||||
user_config = load_config()
|
||||
user_config["lang"] = lang or user_config["lang"]
|
||||
@@ -45,28 +58,34 @@ def save_config(lang: str, model_name: str, model_path: str) -> None:
|
||||
|
||||
def get_model_path(model_name: str) -> str:
|
||||
user_config = load_config()
|
||||
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
|
||||
return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
|
||||
|
||||
|
||||
def get_prefix(model_name: str) -> str:
|
||||
return model_name.split("-")[0]
|
||||
|
||||
|
||||
def get_module(model_name: str) -> str:
|
||||
return DEFAULT_MODULE.get(get_prefix(model_name), "q_proj,v_proj")
|
||||
|
||||
|
||||
def get_template(model_name: str) -> str:
|
||||
if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
|
||||
return DEFAULT_TEMPLATE[model_name.split("-")[0]]
|
||||
if model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE:
|
||||
return DEFAULT_TEMPLATE[get_prefix(model_name)]
|
||||
return "default"
|
||||
|
||||
|
||||
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||
checkpoints = []
|
||||
save_dir = get_save_dir(model_name, finetuning_type)
|
||||
if save_dir and os.path.isdir(save_dir):
|
||||
for checkpoint in os.listdir(save_dir):
|
||||
if (
|
||||
os.path.isdir(os.path.join(save_dir, checkpoint))
|
||||
and any([
|
||||
os.path.isfile(os.path.join(save_dir, checkpoint, name))
|
||||
for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
|
||||
])
|
||||
):
|
||||
checkpoints.append(checkpoint)
|
||||
if model_name:
|
||||
save_dir = get_save_dir(model_name, finetuning_type)
|
||||
if save_dir and os.path.isdir(save_dir):
|
||||
for checkpoint in os.listdir(save_dir):
|
||||
if (
|
||||
os.path.isdir(os.path.join(save_dir, checkpoint))
|
||||
and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
|
||||
):
|
||||
checkpoints.append(checkpoint)
|
||||
return gr.update(value=[], choices=checkpoints)
|
||||
|
||||
|
||||
@@ -75,6 +94,7 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except:
|
||||
print("Cannot find {} in {}.".format(DATA_CONFIG, dataset_dir))
|
||||
return {}
|
||||
|
||||
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
def create_chat_box(
|
||||
chat_model: "WebChatModel",
|
||||
engine: "Engine",
|
||||
visible: Optional[bool] = False
|
||||
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Box(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot()
|
||||
|
||||
history = gr.State([])
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
system = gr.Textbox(show_label=False)
|
||||
@@ -23,14 +22,13 @@ def create_chat_box(
|
||||
|
||||
with gr.Column(scale=1):
|
||||
clear_btn = gr.Button()
|
||||
max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1)
|
||||
top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)
|
||||
|
||||
history = gr.State([])
|
||||
gen_kwargs = engine.chatter.generating_args
|
||||
max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1)
|
||||
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
|
||||
|
||||
submit_btn.click(
|
||||
chat_model.predict,
|
||||
engine.chatter.predict,
|
||||
[chatbot, query, history, system, max_new_tokens, top_p, temperature],
|
||||
[chatbot, history],
|
||||
show_progress=True
|
||||
|
||||
@@ -1,21 +1,103 @@
|
||||
import os
|
||||
import json
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Tuple
|
||||
|
||||
from llmtuner.webui.common import DATA_CONFIG
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
|
||||
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
|
||||
PAGE_SIZE = 2
|
||||
|
||||
|
||||
def prev_page(page_index: int) -> int:
|
||||
return page_index - 1 if page_index > 0 else page_index
|
||||
|
||||
|
||||
def next_page(page_index: int, total_num: int) -> int:
|
||||
return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index
|
||||
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
if (
|
||||
len(dataset) > 0
|
||||
and "file_name" in dataset_info[dataset[0]]
|
||||
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
|
||||
):
|
||||
return gr.update(interactive=True)
|
||||
else:
|
||||
return gr.update(interactive=False)
|
||||
|
||||
|
||||
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, Dict[str, Any]]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
data_file: str = dataset_info[dataset[0]]["file_name"]
|
||||
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
|
||||
if data_file.endswith(".json"):
|
||||
data = json.load(f)
|
||||
elif data_file.endswith(".jsonl"):
|
||||
data = [json.loads(line) for line in f]
|
||||
else:
|
||||
data = [line for line in f]
|
||||
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
|
||||
|
||||
|
||||
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
|
||||
data_preview_btn = gr.Button(interactive=False, scale=1)
|
||||
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
|
||||
with gr.Row():
|
||||
preview_count = gr.Number(interactive=False)
|
||||
preview_count = gr.Number(value=0, interactive=False, precision=0)
|
||||
page_index = gr.Number(value=0, interactive=False, precision=0)
|
||||
|
||||
with gr.Row():
|
||||
prev_btn = gr.Button()
|
||||
next_btn = gr.Button()
|
||||
close_btn = gr.Button()
|
||||
|
||||
with gr.Row():
|
||||
preview_samples = gr.JSON(interactive=False)
|
||||
|
||||
close_btn = gr.Button()
|
||||
|
||||
dataset.change(
|
||||
can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False
|
||||
).then(
|
||||
lambda: 0, outputs=[page_index], queue=False
|
||||
)
|
||||
data_preview_btn.click(
|
||||
get_preview,
|
||||
[dataset_dir, dataset, page_index],
|
||||
[preview_count, preview_samples, preview_box],
|
||||
queue=False
|
||||
)
|
||||
prev_btn.click(
|
||||
prev_page, [page_index], [page_index], queue=False
|
||||
).then(
|
||||
get_preview,
|
||||
[dataset_dir, dataset, page_index],
|
||||
[preview_count, preview_samples, preview_box],
|
||||
queue=False
|
||||
)
|
||||
next_btn.click(
|
||||
next_page, [page_index, preview_count], [page_index], queue=False
|
||||
).then(
|
||||
get_preview,
|
||||
[dataset_dir, dataset, page_index],
|
||||
[preview_count, preview_samples, preview_box],
|
||||
queue=False
|
||||
)
|
||||
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
|
||||
|
||||
return preview_box, preview_count, preview_samples, close_btn
|
||||
return dict(
|
||||
data_preview_btn=data_preview_btn,
|
||||
preview_count=preview_count,
|
||||
page_index=page_index,
|
||||
prev_btn=prev_btn,
|
||||
next_btn=next_btn,
|
||||
close_btn=close_btn,
|
||||
preview_samples=preview_samples
|
||||
)
|
||||
|
||||
@@ -1,90 +1,70 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.components.data import create_preview_box
|
||||
from llmtuner.webui.utils import can_preview, get_preview
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.runner import Runner
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
|
||||
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
input_elems = engine.manager.get_base_elems()
|
||||
elem_dict = dict()
|
||||
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
data_preview_btn = gr.Button(interactive=False, scale=1)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
|
||||
data_preview_btn.click(
|
||||
get_preview,
|
||||
[dataset_dir, dataset],
|
||||
[preview_count, preview_samples, preview_box],
|
||||
queue=False
|
||||
)
|
||||
input_elems.update({dataset_dir, dataset})
|
||||
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
|
||||
|
||||
with gr.Row():
|
||||
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
|
||||
predict = gr.Checkbox(value=True)
|
||||
|
||||
input_elems.update({cutoff_len, max_samples, batch_size, predict})
|
||||
elem_dict.update(dict(
|
||||
cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
|
||||
))
|
||||
|
||||
with gr.Row():
|
||||
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
|
||||
top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
|
||||
|
||||
input_elems.update({max_new_tokens, top_p, temperature})
|
||||
elem_dict.update(dict(
|
||||
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
))
|
||||
|
||||
with gr.Row():
|
||||
cmd_preview_btn = gr.Button()
|
||||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Box():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
input_components = [
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["system_prompt"],
|
||||
dataset_dir,
|
||||
dataset,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
max_samples,
|
||||
batch_size,
|
||||
predict
|
||||
]
|
||||
output_elems = [output_box, process_bar]
|
||||
elem_dict.update(dict(
|
||||
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
|
||||
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box
|
||||
))
|
||||
|
||||
output_components = [
|
||||
output_box,
|
||||
process_bar
|
||||
]
|
||||
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
|
||||
start_btn.click(engine.runner.run_eval, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
||||
|
||||
cmd_preview_btn.click(runner.preview_eval, input_components, output_components)
|
||||
start_btn.click(runner.run_eval, input_components, output_components)
|
||||
stop_btn.click(runner.set_abort, queue=False)
|
||||
|
||||
return dict(
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=dataset,
|
||||
data_preview_btn=data_preview_btn,
|
||||
preview_count=preview_count,
|
||||
preview_samples=preview_samples,
|
||||
close_btn=close_btn,
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
max_samples=max_samples,
|
||||
batch_size=batch_size,
|
||||
predict=predict,
|
||||
cmd_preview_btn=cmd_preview_btn,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
output_box=output_box
|
||||
)
|
||||
return elem_dict
|
||||
|
||||
@@ -1,15 +1,56 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List
|
||||
|
||||
from llmtuner.webui.utils import save_model
|
||||
from llmtuner.tuner import export_model
|
||||
from llmtuner.webui.common import get_save_dir
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
|
||||
def save_model(
|
||||
lang: str,
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
template: str,
|
||||
max_shard_size: int,
|
||||
export_dir: str
|
||||
) -> Generator[str, None, None]:
|
||||
error = ""
|
||||
if not model_name:
|
||||
error = ALERTS["err_no_model"][lang]
|
||||
elif not model_path:
|
||||
error = ALERTS["err_no_path"][lang]
|
||||
elif not checkpoints:
|
||||
error = ALERTS["err_no_checkpoint"][lang]
|
||||
elif not export_dir:
|
||||
error = ALERTS["err_no_export_dir"][lang]
|
||||
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error
|
||||
return
|
||||
|
||||
args = dict(
|
||||
model_name_or_path=model_path,
|
||||
checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
export_dir=export_dir
|
||||
)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
export_model(args, max_shard_size="{}GB".format(max_shard_size))
|
||||
yield ALERTS["info_exported"][lang]
|
||||
|
||||
|
||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
save_dir = gr.Textbox()
|
||||
export_dir = gr.Textbox()
|
||||
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
|
||||
|
||||
export_btn = gr.Button()
|
||||
@@ -18,19 +59,20 @@ def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component
|
||||
export_btn.click(
|
||||
save_model,
|
||||
[
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["template"],
|
||||
engine.manager.get_elem_by_name("top.lang"),
|
||||
engine.manager.get_elem_by_name("top.model_name"),
|
||||
engine.manager.get_elem_by_name("top.model_path"),
|
||||
engine.manager.get_elem_by_name("top.checkpoints"),
|
||||
engine.manager.get_elem_by_name("top.finetuning_type"),
|
||||
engine.manager.get_elem_by_name("top.template"),
|
||||
max_shard_size,
|
||||
save_dir
|
||||
export_dir
|
||||
],
|
||||
[info_box]
|
||||
)
|
||||
|
||||
return dict(
|
||||
save_dir=save_dir,
|
||||
export_dir=export_dir,
|
||||
max_shard_size=max_shard_size,
|
||||
export_btn=export_btn,
|
||||
info_box=info_box
|
||||
|
||||
@@ -1,51 +1,39 @@
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
from llmtuner.webui.components.chatbot import create_chat_box
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
|
||||
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
input_elems = engine.manager.get_base_elems()
|
||||
elem_dict = dict()
|
||||
|
||||
with gr.Row():
|
||||
load_btn = gr.Button()
|
||||
unload_btn = gr.Button()
|
||||
|
||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||
elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
|
||||
|
||||
chat_model = WebChatModel(lazy_init=True)
|
||||
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
|
||||
chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
|
||||
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
|
||||
|
||||
load_btn.click(
|
||||
chat_model.load_model,
|
||||
[
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["system_prompt"]
|
||||
],
|
||||
[info_box]
|
||||
engine.chatter.load_model, input_elems, [info_box]
|
||||
).then(
|
||||
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
|
||||
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
|
||||
)
|
||||
|
||||
unload_btn.click(
|
||||
chat_model.unload_model, [top_elems["lang"]], [info_box]
|
||||
engine.chatter.unload_model, input_elems, [info_box]
|
||||
).then(
|
||||
lambda: ([], []), outputs=[chatbot, history]
|
||||
).then(
|
||||
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
|
||||
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
|
||||
)
|
||||
|
||||
return dict(
|
||||
info_box=info_box,
|
||||
load_btn=load_btn,
|
||||
unload_btn=unload_btn,
|
||||
**chat_elems
|
||||
)
|
||||
return elem_dict
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
|
||||
from llmtuner.extras.template import templates
|
||||
from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config
|
||||
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
|
||||
from llmtuner.webui.utils import can_quantize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -26,26 +25,31 @@ def create_top() -> Dict[str, "Component"]:
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
|
||||
system_prompt = gr.Textbox(scale=2)
|
||||
|
||||
lang.change(save_config, [lang, model_name, model_path])
|
||||
with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
flash_attn = gr.Checkbox(value=False)
|
||||
shift_attn = gr.Checkbox(value=False)
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
|
||||
|
||||
model_name.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
|
||||
).then(
|
||||
get_model_path, [model_name], [model_path]
|
||||
get_model_path, [model_name], [model_path], queue=False
|
||||
).then(
|
||||
get_template, [model_name], [template]
|
||||
get_template, [model_name], [template], queue=False
|
||||
) # do not save config since the below line will save
|
||||
|
||||
model_path.change(save_config, [lang, model_name, model_path])
|
||||
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||
|
||||
finetuning_type.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
|
||||
).then(
|
||||
can_quantize, [finetuning_type], [quantization_bit]
|
||||
can_quantize, [finetuning_type], [quantization_bit], queue=False
|
||||
)
|
||||
|
||||
refresh_btn.click(
|
||||
@@ -62,5 +66,9 @@ def create_top() -> Dict[str, "Component"]:
|
||||
advanced_tab=advanced_tab,
|
||||
quantization_bit=quantization_bit,
|
||||
template=template,
|
||||
system_prompt=system_prompt
|
||||
system_prompt=system_prompt,
|
||||
llama_tab=llama_tab,
|
||||
flash_attn=flash_attn,
|
||||
shift_attn=shift_attn,
|
||||
rope_scaling=rope_scaling
|
||||
)
|
||||
|
||||
@@ -1,45 +1,49 @@
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from transformers.trainer_utils import SchedulerType
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from llmtuner.extras.constants import TRAINING_STAGES
|
||||
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.components.data import create_preview_box
|
||||
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
|
||||
from llmtuner.webui.utils import gen_plot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.runner import Runner
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
|
||||
def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
input_elems = engine.manager.get_base_elems()
|
||||
elem_dict = dict()
|
||||
|
||||
with gr.Row():
|
||||
training_stage = gr.Dropdown(
|
||||
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
|
||||
)
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
data_preview_btn = gr.Button(interactive=False, scale=1)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset])
|
||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset])
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
|
||||
data_preview_btn.click(
|
||||
get_preview,
|
||||
[dataset_dir, dataset],
|
||||
[preview_count, preview_samples, preview_box],
|
||||
queue=False
|
||||
)
|
||||
input_elems.update({training_stage, dataset_dir, dataset})
|
||||
elem_dict.update(dict(
|
||||
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
|
||||
))
|
||||
|
||||
with gr.Row():
|
||||
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
|
||||
learning_rate = gr.Textbox(value="5e-5")
|
||||
num_train_epochs = gr.Textbox(value="3.0")
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
|
||||
|
||||
input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
|
||||
elem_dict.update(dict(
|
||||
cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
|
||||
max_samples=max_samples, compute_type=compute_type
|
||||
))
|
||||
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||
@@ -50,33 +54,59 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
|
||||
max_grad_norm = gr.Textbox(value="1.0")
|
||||
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
|
||||
|
||||
input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
|
||||
elem_dict.update(dict(
|
||||
batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
|
||||
))
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
|
||||
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
|
||||
neft_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
|
||||
|
||||
with gr.Column():
|
||||
train_on_prompt = gr.Checkbox(value=False)
|
||||
upcast_layernorm = gr.Checkbox(value=False)
|
||||
|
||||
input_elems.update({logging_steps, save_steps, warmup_steps, neft_alpha, train_on_prompt, upcast_layernorm})
|
||||
elem_dict.update(dict(
|
||||
advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
|
||||
neft_alpha=neft_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm
|
||||
))
|
||||
|
||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
|
||||
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
lora_target = gr.Textbox(scale=2)
|
||||
lora_target = gr.Textbox(scale=1)
|
||||
additional_target = gr.Textbox(scale=1)
|
||||
resume_lora_training = gr.Checkbox(value=True, scale=1)
|
||||
|
||||
input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, resume_lora_training})
|
||||
elem_dict.update(dict(
|
||||
lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target,
|
||||
additional_target=additional_target, resume_lora_training=resume_lora_training,
|
||||
))
|
||||
|
||||
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
|
||||
with gr.Row():
|
||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
|
||||
reward_model = gr.Dropdown(scale=2)
|
||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
reward_model = gr.Dropdown(scale=3)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
refresh_btn.click(
|
||||
list_checkpoint,
|
||||
[top_elems["model_name"], top_elems["finetuning_type"]],
|
||||
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
|
||||
[reward_model],
|
||||
queue=False
|
||||
)
|
||||
|
||||
input_elems.update({dpo_beta, reward_model})
|
||||
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn))
|
||||
|
||||
with gr.Row():
|
||||
cmd_preview_btn = gr.Button()
|
||||
start_btn = gr.Button()
|
||||
@@ -88,6 +118,7 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
|
||||
output_dir = gr.Textbox()
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Box():
|
||||
@@ -96,89 +127,28 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
|
||||
with gr.Column(scale=1):
|
||||
loss_viewer = gr.Plot()
|
||||
|
||||
input_components = [
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["system_prompt"],
|
||||
training_stage,
|
||||
dataset_dir,
|
||||
dataset,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
learning_rate,
|
||||
num_train_epochs,
|
||||
max_samples,
|
||||
batch_size,
|
||||
gradient_accumulation_steps,
|
||||
lr_scheduler_type,
|
||||
max_grad_norm,
|
||||
val_size,
|
||||
logging_steps,
|
||||
save_steps,
|
||||
warmup_steps,
|
||||
compute_type,
|
||||
lora_rank,
|
||||
lora_dropout,
|
||||
lora_target,
|
||||
resume_lora_training,
|
||||
dpo_beta,
|
||||
reward_model,
|
||||
output_dir
|
||||
]
|
||||
input_elems.add(output_dir)
|
||||
output_elems = [output_box, process_bar]
|
||||
|
||||
output_components = [
|
||||
output_box,
|
||||
process_bar
|
||||
]
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
|
||||
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
||||
|
||||
cmd_preview_btn.click(runner.preview_train, input_components, output_components)
|
||||
start_btn.click(runner.run_train, input_components, output_components)
|
||||
stop_btn.click(runner.set_abort, queue=False)
|
||||
elem_dict.update(dict(
|
||||
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
|
||||
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
|
||||
))
|
||||
|
||||
process_bar.change(
|
||||
gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
|
||||
output_box.change(
|
||||
gen_plot,
|
||||
[
|
||||
engine.manager.get_elem_by_name("top.model_name"),
|
||||
engine.manager.get_elem_by_name("top.finetuning_type"),
|
||||
output_dir
|
||||
],
|
||||
loss_viewer,
|
||||
queue=False
|
||||
)
|
||||
|
||||
return dict(
|
||||
training_stage=training_stage,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=dataset,
|
||||
data_preview_btn=data_preview_btn,
|
||||
preview_count=preview_count,
|
||||
preview_samples=preview_samples,
|
||||
close_btn=close_btn,
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
learning_rate=learning_rate,
|
||||
num_train_epochs=num_train_epochs,
|
||||
max_samples=max_samples,
|
||||
batch_size=batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
max_grad_norm=max_grad_norm,
|
||||
val_size=val_size,
|
||||
advanced_tab=advanced_tab,
|
||||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
warmup_steps=warmup_steps,
|
||||
compute_type=compute_type,
|
||||
lora_tab=lora_tab,
|
||||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target,
|
||||
resume_lora_training=resume_lora_training,
|
||||
rlhf_tab=rlhf_tab,
|
||||
dpo_beta=dpo_beta,
|
||||
reward_model=reward_model,
|
||||
refresh_btn=refresh_btn,
|
||||
cmd_preview_btn=cmd_preview_btn,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
output_dir=output_dir,
|
||||
output_box=output_box,
|
||||
loss_viewer=loss_viewer
|
||||
)
|
||||
return elem_dict
|
||||
|
||||
@@ -6,10 +6,12 @@ CSS = r"""
|
||||
transform: translate(-50%, -50%); /* center horizontally */
|
||||
max-width: 1000px;
|
||||
max-height: 750px;
|
||||
overflow-y: scroll !important;
|
||||
overflow-y: auto;
|
||||
background-color: var(--input-background-fill);
|
||||
flex-wrap: nowrap !important;
|
||||
border: 2px solid black !important;
|
||||
z-index: 1000;
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
.dark .modal-box {
|
||||
|
||||
57
src/llmtuner/webui/engine.py
Normal file
57
src/llmtuner/webui/engine.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from llmtuner.webui.chatter import WebChatModel
|
||||
from llmtuner.webui.common import get_model_path, list_dataset, load_config
|
||||
from llmtuner.webui.locales import LOCALES
|
||||
from llmtuner.webui.manager import Manager
|
||||
from llmtuner.webui.runner import Runner
|
||||
from llmtuner.webui.utils import get_time
|
||||
|
||||
|
||||
class Engine:
|
||||
|
||||
def __init__(self, pure_chat: Optional[bool] = False) -> None:
|
||||
self.pure_chat = pure_chat
|
||||
self.manager: "Manager" = Manager()
|
||||
self.runner: "Runner" = Runner(self.manager)
|
||||
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
|
||||
|
||||
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
|
||||
return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}
|
||||
|
||||
def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
|
||||
user_config = load_config()
|
||||
lang = user_config.get("lang", None) or "en"
|
||||
|
||||
init_dict = {
|
||||
"top.lang": {"value": lang},
|
||||
"infer.chat_box": {"visible": self.chatter.loaded}
|
||||
}
|
||||
|
||||
if not self.pure_chat:
|
||||
init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
|
||||
init_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
|
||||
|
||||
if user_config.get("last_model", None):
|
||||
init_dict["top.model_name"] = {"value": user_config["last_model"]}
|
||||
init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
|
||||
|
||||
yield self._form_dict(init_dict)
|
||||
|
||||
if not self.pure_chat:
|
||||
if self.runner.alive:
|
||||
yield {elem: gr.update(value=value) for elem, value in self.runner.running_data.items()}
|
||||
if self.runner.do_train:
|
||||
yield self._form_dict({"train.resume_btn": {"value": True}})
|
||||
else:
|
||||
yield self._form_dict({"eval.resume_btn": {"value": True}})
|
||||
else:
|
||||
yield self._form_dict({"train.output_dir": {"value": get_time()}})
|
||||
|
||||
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
||||
return {
|
||||
component: gr.update(**LOCALES[name][lang])
|
||||
for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES
|
||||
}
|
||||
@@ -9,65 +9,53 @@ from llmtuner.webui.components import (
|
||||
create_export_tab,
|
||||
create_chat_box
|
||||
)
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
from llmtuner.webui.common import save_config
|
||||
from llmtuner.webui.css import CSS
|
||||
from llmtuner.webui.manager import Manager
|
||||
from llmtuner.webui.runner import Runner
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
|
||||
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
|
||||
|
||||
|
||||
def create_ui() -> gr.Blocks:
|
||||
runner = Runner()
|
||||
engine = Engine(pure_chat=False)
|
||||
|
||||
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
|
||||
top_elems = create_top()
|
||||
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
|
||||
engine.manager.all_elems["top"] = create_top()
|
||||
lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
|
||||
|
||||
with gr.Tab("Train"):
|
||||
train_elems = create_train_tab(top_elems, runner)
|
||||
engine.manager.all_elems["train"] = create_train_tab(engine)
|
||||
|
||||
with gr.Tab("Evaluate"):
|
||||
eval_elems = create_eval_tab(top_elems, runner)
|
||||
engine.manager.all_elems["eval"] = create_eval_tab(engine)
|
||||
|
||||
with gr.Tab("Chat"):
|
||||
infer_elems = create_infer_tab(top_elems)
|
||||
engine.manager.all_elems["infer"] = create_infer_tab(engine)
|
||||
|
||||
with gr.Tab("Export"):
|
||||
export_elems = create_export_tab(top_elems)
|
||||
engine.manager.all_elems["export"] = create_export_tab(engine)
|
||||
|
||||
elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
|
||||
manager = Manager(elem_list)
|
||||
|
||||
demo.load(
|
||||
manager.gen_label,
|
||||
[top_elems["lang"]],
|
||||
[elem for elems in elem_list for elem in elems.values()],
|
||||
)
|
||||
|
||||
top_elems["lang"].change(
|
||||
manager.gen_label,
|
||||
[top_elems["lang"]],
|
||||
[elem for elems in elem_list for elem in elems.values()],
|
||||
queue=False
|
||||
)
|
||||
demo.load(engine.resume, outputs=engine.manager.list_elems())
|
||||
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
|
||||
lang.input(save_config, inputs=[lang], queue=False)
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
def create_web_demo() -> gr.Blocks:
|
||||
chat_model = WebChatModel(lazy_init=False)
|
||||
engine = Engine(pure_chat=True)
|
||||
|
||||
with gr.Blocks(title="Web Demo", css=CSS) as demo:
|
||||
lang = gr.Dropdown(choices=["en", "zh"], value="en")
|
||||
lang = gr.Dropdown(choices=["en", "zh"])
|
||||
engine.manager.all_elems["top"] = dict(lang=lang)
|
||||
|
||||
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
|
||||
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
|
||||
engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)
|
||||
|
||||
manager = Manager([{"lang": lang}, chat_elems])
|
||||
|
||||
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
|
||||
|
||||
lang.select(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False)
|
||||
demo.load(engine.resume, outputs=engine.manager.list_elems())
|
||||
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
|
||||
lang.input(save_config, inputs=[lang], queue=False)
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
@@ -59,12 +59,12 @@ LOCALES = {
|
||||
},
|
||||
"quantization_bit": {
|
||||
"en": {
|
||||
"label": "Quantization bit (optional)",
|
||||
"info": "Enable 4/8-bit model quantization."
|
||||
"label": "Quantization bit",
|
||||
"info": "Enable 4/8-bit model quantization (QLoRA)."
|
||||
},
|
||||
"zh": {
|
||||
"label": "量化等级(非必填)",
|
||||
"info": "启用 4/8 比特模型量化。"
|
||||
"label": "量化等级",
|
||||
"info": "启用 4/8 比特模型量化(QLoRA)。"
|
||||
}
|
||||
},
|
||||
"template": {
|
||||
@@ -87,6 +87,38 @@ LOCALES = {
|
||||
"info": "默认使用的系统提示词"
|
||||
}
|
||||
},
|
||||
"llama_tab": {
|
||||
"en": {
|
||||
"label": "Model configurations (LLaMA only)"
|
||||
},
|
||||
"zh": {
|
||||
"label": "模型设置(仅LLaMA)"
|
||||
}
|
||||
},
|
||||
"flash_attn": {
|
||||
"en": {
|
||||
"label": "Use FlashAttention-2"
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 FlashAttention-2"
|
||||
}
|
||||
},
|
||||
"shift_attn": {
|
||||
"en": {
|
||||
"label": "Use shift short attention (S^2-Attn)"
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 shift short attention (S^2-Attn)"
|
||||
}
|
||||
},
|
||||
"rope_scaling": {
|
||||
"en": {
|
||||
"label": "RoPE scaling"
|
||||
},
|
||||
"zh": {
|
||||
"label": "RoPE 插值方法"
|
||||
}
|
||||
},
|
||||
"training_stage": {
|
||||
"en": {
|
||||
"label": "Stage",
|
||||
@@ -131,12 +163,28 @@ LOCALES = {
|
||||
"label": "数量"
|
||||
}
|
||||
},
|
||||
"preview_samples": {
|
||||
"page_index": {
|
||||
"en": {
|
||||
"label": "Samples"
|
||||
"label": "Page"
|
||||
},
|
||||
"zh": {
|
||||
"label": "样例"
|
||||
"label": "页数"
|
||||
}
|
||||
},
|
||||
"prev_btn": {
|
||||
"en": {
|
||||
"value": "Prev"
|
||||
},
|
||||
"zh": {
|
||||
"value": "上一页"
|
||||
}
|
||||
},
|
||||
"next_btn": {
|
||||
"en": {
|
||||
"value": "Next"
|
||||
},
|
||||
"zh": {
|
||||
"value": "下一页"
|
||||
}
|
||||
},
|
||||
"close_btn": {
|
||||
@@ -147,24 +195,22 @@ LOCALES = {
|
||||
"value": "关闭"
|
||||
}
|
||||
},
|
||||
"max_source_length": {
|
||||
"preview_samples": {
|
||||
"en": {
|
||||
"label": "Max source length",
|
||||
"info": "Max tokens in source sequence."
|
||||
"label": "Samples"
|
||||
},
|
||||
"zh": {
|
||||
"label": "输入序列最大长度",
|
||||
"info": "输入序列分词后的最大长度。"
|
||||
"label": "样例"
|
||||
}
|
||||
},
|
||||
"max_target_length": {
|
||||
"cutoff_len": {
|
||||
"en": {
|
||||
"label": "Max target length",
|
||||
"info": "Max tokens in target sequence."
|
||||
"label": "Cutoff length",
|
||||
"info": "Max tokens in input sequence."
|
||||
},
|
||||
"zh": {
|
||||
"label": "输出序列最大长度",
|
||||
"info": "输出序列分词后的最大长度。"
|
||||
"label": "截断长度",
|
||||
"info": "输入序列分词后的最大长度。"
|
||||
}
|
||||
},
|
||||
"learning_rate": {
|
||||
@@ -197,6 +243,16 @@ LOCALES = {
|
||||
"info": "每个数据集最多使用的样本数。"
|
||||
}
|
||||
},
|
||||
"compute_type": {
|
||||
"en": {
|
||||
"label": "Compute type",
|
||||
"info": "Whether to use fp16 or bf16 mixed precision training."
|
||||
},
|
||||
"zh": {
|
||||
"label": "计算类型",
|
||||
"info": "是否启用 FP16 或 BF16 混合精度训练。"
|
||||
}
|
||||
},
|
||||
"batch_size": {
|
||||
"en": {
|
||||
"label": "Batch size",
|
||||
@@ -277,14 +333,34 @@ LOCALES = {
|
||||
"info": "学习率预热采用的步数。"
|
||||
}
|
||||
},
|
||||
"compute_type": {
|
||||
"neft_alpha": {
|
||||
"en": {
|
||||
"label": "Compute type",
|
||||
"info": "Whether to use fp16 or bf16 mixed precision training."
|
||||
"label": "NEFTune Alpha",
|
||||
"info": "Magnitude of noise adding to embedding vectors."
|
||||
},
|
||||
"zh": {
|
||||
"label": "计算类型",
|
||||
"info": "是否启用 FP16 或 BF16 混合精度训练。"
|
||||
"label": "NEFTune 噪声参数",
|
||||
"info": "嵌入向量所添加的噪声大小。"
|
||||
}
|
||||
},
|
||||
"train_on_prompt": {
|
||||
"en": {
|
||||
"label": "Train on prompt",
|
||||
"info": "Compute loss on the prompt tokens in supervised fine-tuning."
|
||||
},
|
||||
"zh": {
|
||||
"label": "计算输入损失",
|
||||
"info": "在监督微调时候计算输入序列的损失。"
|
||||
}
|
||||
},
|
||||
"upcast_layernorm": {
|
||||
"en": {
|
||||
"label": "Upcast LayerNorm",
|
||||
"info": "Upcast weights of layernorm in float32."
|
||||
},
|
||||
"zh": {
|
||||
"label": "缩放归一化层",
|
||||
"info": "将归一化层权重缩放至 32 位浮点数。"
|
||||
}
|
||||
},
|
||||
"lora_tab": {
|
||||
@@ -318,11 +394,21 @@ LOCALES = {
|
||||
"lora_target": {
|
||||
"en": {
|
||||
"label": "LoRA modules (optional)",
|
||||
"info": "The name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
|
||||
"info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 作用层(非必填)",
|
||||
"info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
|
||||
"label": "LoRA 作用模块(非必填)",
|
||||
"info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"
|
||||
}
|
||||
},
|
||||
"additional_target": {
|
||||
"en": {
|
||||
"label": "Additional modules (optional)",
|
||||
"info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules."
|
||||
},
|
||||
"zh": {
|
||||
"label": "附加模块(非必填)",
|
||||
"info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"
|
||||
}
|
||||
},
|
||||
"resume_lora_training": {
|
||||
@@ -356,11 +442,11 @@ LOCALES = {
|
||||
"reward_model": {
|
||||
"en": {
|
||||
"label": "Reward model",
|
||||
"info": "Checkpoint of the reward model for PPO training."
|
||||
"info": "Checkpoint of the reward model for PPO training. (Needs to refresh checkpoints)"
|
||||
},
|
||||
"zh": {
|
||||
"label": "奖励模型",
|
||||
"info": "PPO 训练中奖励模型的断点路径。"
|
||||
"info": "PPO 训练中奖励模型的断点路径。(需要刷新断点)"
|
||||
}
|
||||
},
|
||||
"cmd_preview_btn": {
|
||||
@@ -509,7 +595,7 @@ LOCALES = {
|
||||
"label": "温度系数"
|
||||
}
|
||||
},
|
||||
"save_dir": {
|
||||
"export_dir": {
|
||||
"en": {
|
||||
"label": "Export dir",
|
||||
"info": "Directory to save exported model."
|
||||
@@ -565,7 +651,7 @@ ALERTS = {
|
||||
"en": "Please select a checkpoint.",
|
||||
"zh": "请选择断点。"
|
||||
},
|
||||
"err_no_save_dir": {
|
||||
"err_no_export_dir": {
|
||||
"en": "Please provide export dir.",
|
||||
"zh": "请填写导出目录"
|
||||
},
|
||||
|
||||
@@ -1,46 +1,35 @@
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
from typing import Any, Dict, List
|
||||
from typing import TYPE_CHECKING, Dict, List, Set
|
||||
|
||||
from llmtuner.webui.common import get_model_path, list_dataset, load_config
|
||||
from llmtuner.webui.locales import LOCALES
|
||||
from llmtuner.webui.utils import get_time
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
class Manager:
|
||||
|
||||
def __init__(self, elem_list: List[Dict[str, Component]]):
|
||||
self.elem_list = elem_list
|
||||
def __init__(self) -> None:
|
||||
self.all_elems: Dict[str, Dict[str, "Component"]] = {}
|
||||
|
||||
def gen_refresh(self, lang: str) -> Dict[str, Any]:
|
||||
refresh_dict = {
|
||||
"dataset": {"choices": list_dataset()["choices"]},
|
||||
"output_dir": {"value": get_time()}
|
||||
def get_elem_by_name(self, name: str) -> "Component":
|
||||
r"""
|
||||
Example: top.lang, train.dataset
|
||||
"""
|
||||
tab_name, elem_name = name.split(".")
|
||||
return self.all_elems[tab_name][elem_name]
|
||||
|
||||
def get_base_elems(self) -> Set["Component"]:
|
||||
return {
|
||||
self.all_elems["top"]["lang"],
|
||||
self.all_elems["top"]["model_name"],
|
||||
self.all_elems["top"]["model_path"],
|
||||
self.all_elems["top"]["checkpoints"],
|
||||
self.all_elems["top"]["finetuning_type"],
|
||||
self.all_elems["top"]["quantization_bit"],
|
||||
self.all_elems["top"]["template"],
|
||||
self.all_elems["top"]["system_prompt"],
|
||||
self.all_elems["top"]["flash_attn"],
|
||||
self.all_elems["top"]["shift_attn"],
|
||||
self.all_elems["top"]["rope_scaling"]
|
||||
}
|
||||
|
||||
user_config = load_config()
|
||||
if not lang:
|
||||
if user_config.get("lang", None):
|
||||
lang = user_config["lang"]
|
||||
else:
|
||||
lang = "en"
|
||||
|
||||
refresh_dict["lang"] = {"value": lang}
|
||||
|
||||
if user_config.get("last_model", None):
|
||||
refresh_dict["model_name"] = {"value": user_config["last_model"]}
|
||||
refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
|
||||
|
||||
return refresh_dict
|
||||
|
||||
def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
|
||||
update_dict = {}
|
||||
refresh_dict = self.gen_refresh(lang)
|
||||
|
||||
for elems in self.elem_list:
|
||||
for name, component in elems.items():
|
||||
update_dict[component] = gr.update(
|
||||
**LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {})
|
||||
)
|
||||
|
||||
return update_dict
|
||||
def list_elems(self) -> List["Component"]:
|
||||
return [elem for elems in self.all_elems.values() for elem in elems.values()]
|
||||
|
||||
@@ -1,46 +1,65 @@
|
||||
import gradio as gr
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import logging
|
||||
import gradio as gr
|
||||
from threading import Thread
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
|
||||
|
||||
import transformers
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from typing import Any, Dict, Generator, List, Tuple
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES
|
||||
from llmtuner.extras.constants import TRAINING_STAGES
|
||||
from llmtuner.extras.logging import LoggerHandler
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.tuner import run_exp
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir, load_config
|
||||
from llmtuner.webui.common import get_module, get_save_dir, load_config
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.webui.manager import Manager
|
||||
|
||||
|
||||
class Runner:
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, manager: "Manager") -> None:
|
||||
self.manager = manager
|
||||
""" Resume """
|
||||
self.thread: "Thread" = None
|
||||
self.do_train = True
|
||||
self.running_data: Dict["Component", Any] = None
|
||||
self.monitor_inputs: Dict[str, str] = None
|
||||
""" State """
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
""" Handler """
|
||||
self.logger_handler = LoggerHandler()
|
||||
self.logger_handler.setLevel(logging.INFO)
|
||||
logging.root.addHandler(self.logger_handler)
|
||||
transformers.logging.add_handler(self.logger_handler)
|
||||
|
||||
def set_abort(self):
|
||||
@property
|
||||
def alive(self) -> bool:
|
||||
return self.thread is not None
|
||||
|
||||
def set_abort(self) -> None:
|
||||
self.aborted = True
|
||||
self.running = False
|
||||
|
||||
def _initialize(
|
||||
self, lang: str, model_name: str, dataset: List[str]
|
||||
) -> str:
|
||||
def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
|
||||
dataset = get("train.dataset") if do_train else get("eval.dataset")
|
||||
|
||||
if self.running:
|
||||
return ALERTS["err_conflict"][lang]
|
||||
|
||||
if not model_name:
|
||||
return ALERTS["err_no_model"][lang]
|
||||
|
||||
if not get_model_path(model_name):
|
||||
if not model_path:
|
||||
return ALERTS["err_no_path"][lang]
|
||||
|
||||
if len(dataset) == 0:
|
||||
@@ -51,9 +70,8 @@ class Runner:
|
||||
self.trainer_callback = LogCallback(self)
|
||||
return ""
|
||||
|
||||
def _finalize(
|
||||
self, lang: str, finish_info: str
|
||||
) -> str:
|
||||
def _finalize(self, lang: str, finish_info: str) -> str:
|
||||
self.thread = None
|
||||
self.running = False
|
||||
torch_gc()
|
||||
if self.aborted:
|
||||
@@ -61,222 +79,176 @@ class Runner:
|
||||
else:
|
||||
return finish_info
|
||||
|
||||
def _parse_train_args(
|
||||
self,
|
||||
lang: str,
|
||||
model_name: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
system_prompt: str,
|
||||
training_stage: str,
|
||||
dataset_dir: str,
|
||||
dataset: List[str],
|
||||
max_source_length: int,
|
||||
max_target_length: int,
|
||||
learning_rate: str,
|
||||
num_train_epochs: str,
|
||||
max_samples: str,
|
||||
batch_size: int,
|
||||
gradient_accumulation_steps: int,
|
||||
lr_scheduler_type: str,
|
||||
max_grad_norm: str,
|
||||
val_size: float,
|
||||
logging_steps: int,
|
||||
save_steps: int,
|
||||
warmup_steps: int,
|
||||
compute_type: str,
|
||||
lora_rank: int,
|
||||
lora_dropout: float,
|
||||
lora_target: str,
|
||||
resume_lora_training: bool,
|
||||
dpo_beta: float,
|
||||
reward_model: str,
|
||||
output_dir: str
|
||||
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
|
||||
)
|
||||
def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
||||
])
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
output_dir = get_save_dir(model_name, finetuning_type, output_dir)
|
||||
|
||||
user_config = load_config()
|
||||
cache_dir = user_config.get("cache_dir", None)
|
||||
|
||||
args = dict(
|
||||
stage=TRAINING_STAGES[training_stage],
|
||||
model_name_or_path=get_model_path(model_name),
|
||||
stage=TRAINING_STAGES[get("train.training_stage")],
|
||||
model_name_or_path=get("top.model_path"),
|
||||
do_train=True,
|
||||
overwrite_cache=False,
|
||||
cache_dir=cache_dir,
|
||||
cache_dir=user_config.get("cache_dir", None),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
|
||||
template=template,
|
||||
system_prompt=system_prompt,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
learning_rate=float(learning_rate),
|
||||
num_train_epochs=float(num_train_epochs),
|
||||
max_samples=int(max_samples),
|
||||
per_device_train_batch_size=batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
max_grad_norm=float(max_grad_norm),
|
||||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
warmup_steps=warmup_steps,
|
||||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
||||
resume_lora_training=(
|
||||
False if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] else resume_lora_training
|
||||
),
|
||||
output_dir=output_dir
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
template=get("top.template"),
|
||||
system_prompt=get("top.system_prompt"),
|
||||
flash_attn=get("top.flash_attn"),
|
||||
shift_attn=get("top.shift_attn"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
dataset_dir=get("train.dataset_dir"),
|
||||
dataset=",".join(get("train.dataset")),
|
||||
cutoff_len=get("train.cutoff_len"),
|
||||
learning_rate=float(get("train.learning_rate")),
|
||||
num_train_epochs=float(get("train.num_train_epochs")),
|
||||
max_samples=int(get("train.max_samples")),
|
||||
per_device_train_batch_size=get("train.batch_size"),
|
||||
gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
|
||||
lr_scheduler_type=get("train.lr_scheduler_type"),
|
||||
max_grad_norm=float(get("train.max_grad_norm")),
|
||||
logging_steps=get("train.logging_steps"),
|
||||
save_steps=get("train.save_steps"),
|
||||
warmup_steps=get("train.warmup_steps"),
|
||||
neft_alpha=get("train.neft_alpha"),
|
||||
train_on_prompt=get("train.train_on_prompt"),
|
||||
upcast_layernorm=get("train.upcast_layernorm"),
|
||||
lora_rank=get("train.lora_rank"),
|
||||
lora_dropout=get("train.lora_dropout"),
|
||||
lora_target=get("train.lora_target") or get_module(get("top.model_name")),
|
||||
additional_target=get("train.additional_target") if get("train.additional_target") else None,
|
||||
resume_lora_training=get("train.resume_lora_training"),
|
||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
|
||||
)
|
||||
args[compute_type] = True
|
||||
args[get("train.compute_type")] = True
|
||||
args["disable_tqdm"] = True
|
||||
|
||||
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
|
||||
args["resume_lora_training"] = (args["quantization_bit"] is not None)
|
||||
|
||||
if args["quantization_bit"] is not None:
|
||||
args["upcast_layernorm"] = True
|
||||
|
||||
if args["stage"] == "ppo":
|
||||
args["reward_model"] = reward_model
|
||||
val_size = 0
|
||||
args["reward_model"] = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.reward_model"))
|
||||
|
||||
if args["stage"] == "dpo":
|
||||
args["dpo_beta"] = dpo_beta
|
||||
args["dpo_beta"] = get("train.dpo_beta")
|
||||
|
||||
if val_size > 1e-6:
|
||||
args["val_size"] = val_size
|
||||
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
|
||||
args["val_size"] = get("train.val_size")
|
||||
args["evaluation_strategy"] = "steps"
|
||||
args["eval_steps"] = save_steps
|
||||
args["eval_steps"] = get("train.save_steps")
|
||||
args["load_best_model_at_end"] = True
|
||||
|
||||
return lang, model_name, dataset, output_dir, args
|
||||
return args
|
||||
|
||||
def _parse_eval_args(
|
||||
self,
|
||||
lang: str,
|
||||
model_name: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
system_prompt: str,
|
||||
dataset_dir: str,
|
||||
dataset: List[str],
|
||||
max_source_length: int,
|
||||
max_target_length: int,
|
||||
max_samples: str,
|
||||
batch_size: int,
|
||||
predict: bool
|
||||
) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
|
||||
def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
||||
])
|
||||
output_dir = get_save_dir(
|
||||
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
|
||||
)
|
||||
output_dir = get_save_dir(model_name, finetuning_type, "eval_" + "_".join(checkpoints))
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
output_dir = get_save_dir(model_name, finetuning_type, "eval_base")
|
||||
|
||||
user_config = load_config()
|
||||
cache_dir = user_config.get("cache_dir", None)
|
||||
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")
|
||||
|
||||
args = dict(
|
||||
stage="sft",
|
||||
model_name_or_path=get_model_path(model_name),
|
||||
model_name_or_path=get("top.model_path"),
|
||||
do_eval=True,
|
||||
overwrite_cache=False,
|
||||
predict_with_generate=True,
|
||||
cache_dir=cache_dir,
|
||||
cache_dir=user_config.get("cache_dir", None),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
|
||||
template=template,
|
||||
system_prompt=system_prompt,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
max_samples=int(max_samples),
|
||||
per_device_eval_batch_size=batch_size,
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
template=get("top.template"),
|
||||
system_prompt=get("top.system_prompt"),
|
||||
flash_attn=get("top.flash_attn"),
|
||||
shift_attn=get("top.shift_attn"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
dataset_dir=get("eval.dataset_dir"),
|
||||
dataset=",".join(get("eval.dataset")),
|
||||
cutoff_len=get("eval.cutoff_len"),
|
||||
max_samples=int(get("eval.max_samples")),
|
||||
per_device_eval_batch_size=get("eval.batch_size"),
|
||||
max_new_tokens=get("eval.max_new_tokens"),
|
||||
top_p=get("eval.top_p"),
|
||||
temperature=get("eval.temperature"),
|
||||
output_dir=output_dir
|
||||
)
|
||||
|
||||
if predict:
|
||||
if get("eval.predict"):
|
||||
args.pop("do_eval", None)
|
||||
args["do_predict"] = True
|
||||
|
||||
return lang, model_name, dataset, output_dir, args
|
||||
return args
|
||||
|
||||
def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
lang, model_name, dataset, _, args = self._parse_train_args(*args)
|
||||
error = self._initialize(lang, model_name, dataset)
|
||||
def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
error = self._initialize(data, do_train)
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error, gr.update(visible=False)
|
||||
else:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
yield gen_cmd(args), gr.update(visible=False)
|
||||
|
||||
def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
lang, model_name, dataset, _, args = self._parse_eval_args(*args)
|
||||
error = self._initialize(lang, model_name, dataset)
|
||||
def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
error = self._initialize(data, do_train)
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error, gr.update(visible=False)
|
||||
else:
|
||||
yield gen_cmd(args), gr.update(visible=False)
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
self.running = True
|
||||
self.do_train, self.running_data = do_train, data
|
||||
self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
|
||||
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||
self.thread.start()
|
||||
yield from self.monitor()
|
||||
|
||||
def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
lang, model_name, dataset, output_dir, args = self._parse_train_args(*args)
|
||||
error = self._initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error, gr.update(visible=False)
|
||||
return
|
||||
def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
yield from self._preview(data, do_train=True)
|
||||
|
||||
self.running = True
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
|
||||
thread.start()
|
||||
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
yield from self._preview(data, do_train=False)
|
||||
|
||||
while thread.is_alive():
|
||||
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
yield from self._launch(data, do_train=True)
|
||||
|
||||
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
yield from self._launch(data, do_train=False)
|
||||
|
||||
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
|
||||
while self.thread.is_alive():
|
||||
time.sleep(2)
|
||||
if self.aborted:
|
||||
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
|
||||
else:
|
||||
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
|
||||
|
||||
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
|
||||
finish_info = ALERTS["info_finished"][lang]
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
|
||||
yield self._finalize(lang, finish_info), gr.update(visible=False)
|
||||
|
||||
def run_eval(self, *args) -> Generator[str, None, None]:
|
||||
lang, model_name, dataset, output_dir, args = self._parse_eval_args(*args)
|
||||
error = self._initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error, gr.update(visible=False)
|
||||
return
|
||||
|
||||
self.running = True
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
|
||||
thread.start()
|
||||
|
||||
while thread.is_alive():
|
||||
time.sleep(2)
|
||||
if self.aborted:
|
||||
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
|
||||
if self.do_train:
|
||||
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
|
||||
finish_info = ALERTS["info_finished"][lang]
|
||||
else:
|
||||
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
|
||||
|
||||
if os.path.exists(os.path.join(output_dir, "all_results.json")):
|
||||
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
if os.path.exists(os.path.join(output_dir, "all_results.json")):
|
||||
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
|
||||
yield self._finalize(lang, finish_info), gr.update(visible=False)
|
||||
|
||||
@@ -3,13 +3,11 @@ import json
|
||||
import gradio as gr
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from llmtuner.extras.ploting import smooth
|
||||
from llmtuner.tuner import export_model
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.common import get_save_dir
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
@@ -33,37 +31,6 @@ def get_time() -> str:
|
||||
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
if (
|
||||
len(dataset) > 0
|
||||
and "file_name" in dataset_info[dataset[0]]
|
||||
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
|
||||
):
|
||||
return gr.update(interactive=True)
|
||||
else:
|
||||
return gr.update(interactive=False)
|
||||
|
||||
|
||||
def get_preview(
|
||||
dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2
|
||||
) -> Tuple[int, list, Dict[str, Any]]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
data_file: str = dataset_info[dataset[0]]["file_name"]
|
||||
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
|
||||
if data_file.endswith(".json"):
|
||||
data = json.load(f)
|
||||
elif data_file.endswith(".jsonl"):
|
||||
data = [json.loads(line) for line in f]
|
||||
else:
|
||||
data = [line for line in f]
|
||||
return len(data), data[start:end], gr.update(visible=True)
|
||||
|
||||
|
||||
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
||||
if finetuning_type != "lora":
|
||||
return gr.update(value="None", interactive=False)
|
||||
@@ -72,8 +39,8 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
if args.get("do_train", None):
|
||||
args["plot_loss"] = True
|
||||
args.pop("disable_tqdm", None)
|
||||
args["plot_loss"] = args.get("do_train", None)
|
||||
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "]
|
||||
for k, v in args.items():
|
||||
if v is not None and v != "":
|
||||
@@ -90,9 +57,11 @@ def get_eval_results(path: os.PathLike) -> str:
|
||||
|
||||
|
||||
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
|
||||
if not base_model:
|
||||
return
|
||||
log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
|
||||
if not os.path.isfile(log_file):
|
||||
return None
|
||||
return
|
||||
|
||||
plt.close("all")
|
||||
fig = plt.figure()
|
||||
@@ -114,46 +83,3 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
|
||||
ax.set_xlabel("step")
|
||||
ax.set_ylabel("loss")
|
||||
return fig
|
||||
|
||||
|
||||
def save_model(
|
||||
lang: str,
|
||||
model_name: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
template: str,
|
||||
max_shard_size: int,
|
||||
save_dir: str
|
||||
) -> Generator[str, None, None]:
|
||||
if not model_name:
|
||||
yield ALERTS["err_no_model"][lang]
|
||||
return
|
||||
|
||||
model_name_or_path = get_model_path(model_name)
|
||||
if not model_name_or_path:
|
||||
yield ALERTS["err_no_path"][lang]
|
||||
return
|
||||
|
||||
if not checkpoints:
|
||||
yield ALERTS["err_no_checkpoint"][lang]
|
||||
return
|
||||
|
||||
checkpoint_dir = ",".join(
|
||||
[get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]
|
||||
)
|
||||
|
||||
if not save_dir:
|
||||
yield ALERTS["err_no_save_dir"][lang]
|
||||
return
|
||||
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
output_dir=save_dir
|
||||
)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
export_model(args, max_shard_size="{}GB".format(max_shard_size))
|
||||
yield ALERTS["info_exported"][lang]
|
||||
|
||||
44
tests/cal_flops.py
Normal file
44
tests/cal_flops.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# coding=utf-8
|
||||
# Calculates the flops of pre-trained models.
|
||||
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
||||
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from typing import Optional
|
||||
from deepspeed.accelerator import get_accelerator # type: ignore
|
||||
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
|
||||
|
||||
from llmtuner import ChatModel
|
||||
|
||||
|
||||
def calculate(
|
||||
model_name_or_path: str,
|
||||
batch_size: Optional[int] = 1,
|
||||
seq_length: Optional[int] = 256,
|
||||
flash_attn: Optional[bool] = False
|
||||
):
|
||||
with get_accelerator().device(0):
|
||||
chat_model = ChatModel(dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
template="vanilla",
|
||||
flash_attn=flash_attn
|
||||
))
|
||||
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
||||
input_dict = {
|
||||
"input_ids": fake_input,
|
||||
"labels": fake_input.clone()
|
||||
}
|
||||
flops, macs, params = get_model_profile(
|
||||
chat_model.model,
|
||||
kwargs=input_dict,
|
||||
print_profile=True,
|
||||
detailed=True
|
||||
)
|
||||
print("FLOPs:", flops)
|
||||
print("MACs:", macs)
|
||||
print("Params:", params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(calculate)
|
||||
@@ -1,133 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Evaluates fine-tuned models automatically.
|
||||
# Usage: python evaluate_zh.py --evalset ceval/ceval-exam:law --split dev --output_file result.json
|
||||
# --api_base http://localhost:8000/v1 --task_type choice --n_samples 100
|
||||
# dataset format: question (string), A (string), B (string), C (string), D (string), answer (Literal["A", "B", "C", "D"])
|
||||
|
||||
|
||||
import os
|
||||
import fire
|
||||
import json
|
||||
import openai
|
||||
from tqdm import tqdm
|
||||
from typing import Literal, Optional
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
def format_example_choice(examples):
|
||||
model_inputs = {"query": [], "label": []}
|
||||
task_template = "请从ABCD四个选项中选出正确的选项,仅输出选项序号。\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\n答案:"
|
||||
for i in range(len(examples["id"])):
|
||||
query = task_template.format(
|
||||
question=examples["question"][i],
|
||||
A=examples["A"][i],
|
||||
B=examples["B"][i],
|
||||
C=examples["C"][i],
|
||||
D=examples["D"][i]
|
||||
)
|
||||
label = examples["answer"][i]
|
||||
model_inputs["query"].append(query)
|
||||
model_inputs["label"].append(label)
|
||||
return model_inputs
|
||||
|
||||
|
||||
def format_example_cloze(examples):
|
||||
model_inputs = {"query": [], "label": []}
|
||||
task_template = "请选择正确的答案填空,仅输出正确的选项。\n{question}\n选项:{A}\n{B}\n{C}\n{D}\n答案:"
|
||||
for i in range(len(examples["id"])):
|
||||
query = task_template.format(
|
||||
question=examples["question"][i],
|
||||
A=examples["A"][i],
|
||||
B=examples["B"][i],
|
||||
C=examples["C"][i],
|
||||
D=examples["D"][i]
|
||||
)
|
||||
label = examples[examples["answer"][i]][i]
|
||||
model_inputs["query"].append(query)
|
||||
model_inputs["label"].append(label)
|
||||
return model_inputs
|
||||
|
||||
|
||||
def format_example_openqa(examples):
|
||||
model_inputs = {"query": [], "label": []}
|
||||
task_template = "回答以下问题:{question}\n答案:"
|
||||
for i in range(len(examples["id"])):
|
||||
query = task_template.format(question=examples["question"][i])
|
||||
label = examples[examples["answer"][i]][i]
|
||||
model_inputs["query"].append(query)
|
||||
model_inputs["label"].append(label)
|
||||
return model_inputs
|
||||
|
||||
|
||||
TASK_DICT = {
|
||||
"choice": format_example_choice,
|
||||
"cloze": format_example_cloze,
|
||||
"openqa": format_example_openqa
|
||||
}
|
||||
|
||||
|
||||
EXT2TYPE = {
|
||||
"csv": "csv",
|
||||
"json": "json",
|
||||
"jsonl": "json"
|
||||
}
|
||||
|
||||
|
||||
def evaluate(
|
||||
evalset: str,
|
||||
api_base: str,
|
||||
output_file: str,
|
||||
split: Optional[str] = "val",
|
||||
task_type: Optional[Literal["choice", "cloze", "openqa"]] = "choice",
|
||||
n_samples: Optional[int] = 20
|
||||
):
|
||||
|
||||
openai.api_base = api_base
|
||||
openai.api_key = "none"
|
||||
|
||||
if os.path.isfile(evalset):
|
||||
dataset = load_dataset(EXT2TYPE[evalset.split(".")[-1]], data_files=evalset)["train"]
|
||||
elif ":" in evalset:
|
||||
evalset, subset = evalset.split(":")
|
||||
dataset = load_dataset(evalset, subset, split=split)
|
||||
else:
|
||||
dataset = load_dataset(evalset, split=split)
|
||||
|
||||
n_samples = min(len(dataset), n_samples)
|
||||
|
||||
dataset = dataset.map(TASK_DICT[task_type], batched=True)
|
||||
dataset = dataset.select(range(n_samples))
|
||||
|
||||
n_correct = 0
|
||||
predictions = []
|
||||
for example in tqdm(dataset):
|
||||
query, label = example["query"], example["label"]
|
||||
predict = openai.ChatCompletion.create(
|
||||
model="default",
|
||||
messages=[{"role": "user", "content": query}],
|
||||
temperature=0.01,
|
||||
top_p=0.01,
|
||||
max_new_tokens=20
|
||||
).choices[0].message.content
|
||||
|
||||
if task_type == "choice" and predict[0].lower() == label[0].lower():
|
||||
n_correct += 1
|
||||
if task_type == "cloze" and label in [predict[:len(label)], predict[-len(label):]]:
|
||||
n_correct += 1
|
||||
if task_type == "openqa" and label in predict:
|
||||
n_correct += 1
|
||||
|
||||
predictions.append({
|
||||
"query": query,
|
||||
"label": label,
|
||||
"predict": predict
|
||||
})
|
||||
|
||||
print("Result: {}/{}\nAccuracy: {:.2f}%".format(n_correct, n_samples, n_correct / n_samples * 100))
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(predictions, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(evaluate)
|
||||
@@ -1,6 +1,6 @@
|
||||
# coding=utf-8
|
||||
# Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
|
||||
# Usage: python llamafy_baichuan2.py --llama2_json llama2.index.json --input_dir input --output_dir output
|
||||
# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output --shard_size 10GB
|
||||
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
|
||||
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
|
||||
|
||||
@@ -9,56 +9,77 @@ import fire
|
||||
import json
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
SHARD_A = "pytorch_model-00001-of-00002.bin"
|
||||
SHARD_B = "pytorch_model-00002-of-00002.bin"
|
||||
CONFIG_NAME = "config.json"
|
||||
|
||||
|
||||
def llamafy_baichuan2(
|
||||
llama2_json: str,
|
||||
def save_weight(
|
||||
input_dir: str,
|
||||
output_dir: str
|
||||
output_dir: str,
|
||||
shard_size: str
|
||||
):
|
||||
baichuan2_state_dict = OrderedDict()
|
||||
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||
for filepath in os.listdir(input_dir):
|
||||
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
|
||||
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
|
||||
baichuan2_state_dict.update(shard_weight)
|
||||
|
||||
llama2_state_dict = OrderedDict()
|
||||
total_size = 0
|
||||
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||
for key, value in baichuan2_state_dict.items():
|
||||
total_size += 2 * value.numel() # half precision
|
||||
if "W_pack" in key:
|
||||
llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:4096, :]
|
||||
llama2_state_dict[key.replace("W_pack", "k_proj")] = value[4096:2*4096, :]
|
||||
llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2*4096:, :]
|
||||
proj_size = value.size(0) // 3
|
||||
llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
|
||||
llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size:2*proj_size, :]
|
||||
llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2*proj_size:, :]
|
||||
elif "lm_head" in key:
|
||||
llama2_state_dict[key] = torch.nn.functional.normalize(value)
|
||||
else:
|
||||
llama2_state_dict[key] = value
|
||||
|
||||
with open(os.path.join(input_dir, llama2_json), "r", encoding="utf-8") as f:
|
||||
llama2_index = json.load(f)
|
||||
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
|
||||
for shard_file, shard in shards.items():
|
||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||
|
||||
if index is None:
|
||||
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
||||
else:
|
||||
with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
|
||||
merged_index = OrderedDict()
|
||||
merged_index["metadata"] = {"total_size": total_size}
|
||||
merged_index["weight_map"] = llama2_index["weight_map"]
|
||||
|
||||
state_dict_a, state_dict_b = OrderedDict(), OrderedDict()
|
||||
for key, value in llama2_state_dict.items():
|
||||
if merged_index["weight_map"][key] == SHARD_A:
|
||||
state_dict_a[key] = value
|
||||
else:
|
||||
state_dict_b[key] = value
|
||||
def save_config(
|
||||
input_dir: str,
|
||||
output_dir: str
|
||||
):
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||
llama2_config_dict: Dict[str, Any] = json.load(f)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
torch.save(state_dict_a, os.path.join(output_dir, SHARD_A))
|
||||
torch.save(state_dict_b, os.path.join(output_dir, SHARD_B))
|
||||
with open(os.path.join(output_dir, "pytorch_model.bin.index.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(merged_index, f, indent=2)
|
||||
print("Completed!")
|
||||
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
|
||||
llama2_config_dict.pop("auto_map", None)
|
||||
llama2_config_dict.pop("tokenizer_class", None)
|
||||
llama2_config_dict["model_type"] = "llama"
|
||||
|
||||
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
||||
json.dump(llama2_config_dict, f, indent=2)
|
||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
||||
|
||||
|
||||
def llamafy_baichuan2(
|
||||
input_dir: str,
|
||||
output_dir: str,
|
||||
shard_size: str
|
||||
):
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=False)
|
||||
except Exception as e:
|
||||
raise print("Output dir already exists", e)
|
||||
|
||||
save_weight(input_dir, output_dir, shard_size)
|
||||
save_config(input_dir, output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
135
tests/llamafy_qwen.py
Normal file
135
tests/llamafy_qwen.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# coding=utf-8
|
||||
# Converts the Qwen models in the same format as LLaMA2.
|
||||
# Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB
|
||||
|
||||
import os
|
||||
import fire
|
||||
import json
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from safetensors import safe_open
|
||||
from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from transformers.utils import check_min_version
|
||||
from typing import Any, Dict
|
||||
|
||||
try:
|
||||
check_min_version("4.34.0")
|
||||
except:
|
||||
raise ValueError("Please upgrade `transformers` to 4.34.0")
|
||||
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
|
||||
|
||||
def save_weight(
|
||||
input_dir: str,
|
||||
output_dir: str,
|
||||
shard_size: str
|
||||
) -> str:
|
||||
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||
for filepath in os.listdir(input_dir):
|
||||
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
|
||||
with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
qwen_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||
torch_dtype = None
|
||||
for key, value in qwen_state_dict.items():
|
||||
if torch_dtype is None:
|
||||
torch_dtype = value.dtype
|
||||
if "wte" in key:
|
||||
llama2_state_dict["model.embed_tokens.weight"] = value
|
||||
elif "ln_f" in key:
|
||||
llama2_state_dict["model.norm.weight"] = value
|
||||
else:
|
||||
key = key.replace("transformer.h", "model.layers")
|
||||
if "attn.c_attn" in key:
|
||||
proj_size = value.size(0) // 3
|
||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
|
||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[proj_size:2*proj_size, ...]
|
||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2*proj_size:, ...]
|
||||
elif "attn.c_proj" in key:
|
||||
llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
|
||||
llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = (
|
||||
torch.zeros_like(value[:, 0]).squeeze()
|
||||
)
|
||||
elif "ln_1" in key:
|
||||
llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
|
||||
elif "ln_2" in key:
|
||||
llama2_state_dict[key.replace("ln_2", "post_attention_layernorm")] = value
|
||||
elif "mlp.w1" in key:
|
||||
llama2_state_dict[key.replace("mlp.w1", "mlp.up_proj")] = value
|
||||
elif "mlp.w2" in key:
|
||||
llama2_state_dict[key.replace("mlp.w2", "mlp.gate_proj")] = value
|
||||
elif "mlp.c_proj" in key:
|
||||
llama2_state_dict[key.replace("mlp.c_proj", "mlp.down_proj")] = value
|
||||
elif "lm_head" in key:
|
||||
llama2_state_dict[key] = value
|
||||
else:
|
||||
raise KeyError("Unable to process key {}".format(key))
|
||||
|
||||
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
|
||||
for shard_file, shard in shards.items():
|
||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||
|
||||
if index is None:
|
||||
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
||||
else:
|
||||
with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
|
||||
return str(torch_dtype).replace("torch.", "")
|
||||
|
||||
|
||||
def save_config(
|
||||
input_dir: str,
|
||||
output_dir: str,
|
||||
torch_dtype: str
|
||||
):
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||
qwen_config_dict: Dict[str, Any] = json.load(f)
|
||||
|
||||
llama2_config_dict: Dict[str, Any] = OrderedDict()
|
||||
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
|
||||
llama2_config_dict["hidden_act"] = "silu"
|
||||
llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
|
||||
llama2_config_dict["initializer_range"] = qwen_config_dict["initializer_range"]
|
||||
llama2_config_dict["intermediate_size"] = qwen_config_dict["intermediate_size"] // 2
|
||||
llama2_config_dict["max_position_embeddings"] = qwen_config_dict["max_position_embeddings"]
|
||||
llama2_config_dict["model_type"] = "llama"
|
||||
llama2_config_dict["num_attention_heads"] = qwen_config_dict["num_attention_heads"]
|
||||
llama2_config_dict["num_hidden_layers"] = qwen_config_dict["num_hidden_layers"]
|
||||
llama2_config_dict["num_key_value_heads"] = qwen_config_dict["hidden_size"] // qwen_config_dict["kv_channels"]
|
||||
llama2_config_dict["pretraining_tp"] = 1
|
||||
llama2_config_dict["rms_norm_eps"] = qwen_config_dict["layer_norm_epsilon"]
|
||||
llama2_config_dict["rope_scaling"] = None
|
||||
llama2_config_dict["tie_word_embeddings"] = qwen_config_dict["tie_word_embeddings"]
|
||||
llama2_config_dict["torch_dtype"] = torch_dtype
|
||||
llama2_config_dict["transformers_version"] = "4.34.0"
|
||||
llama2_config_dict["use_cache"] = True
|
||||
llama2_config_dict["vocab_size"] = qwen_config_dict["vocab_size"]
|
||||
llama2_config_dict["attention_bias"] = True
|
||||
|
||||
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
||||
json.dump(llama2_config_dict, f, indent=2)
|
||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
||||
|
||||
|
||||
def llamafy_qwen(
|
||||
input_dir: str,
|
||||
output_dir: str,
|
||||
shard_size: str
|
||||
):
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=False)
|
||||
except Exception as e:
|
||||
raise print("Output dir already exists", e)
|
||||
|
||||
torch_dtype = save_weight(input_dir, output_dir, shard_size)
|
||||
save_config(input_dir, output_dir, torch_dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(llamafy_qwen)
|
||||
@@ -1,654 +0,0 @@
|
||||
# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
|
||||
# Modified by hiyouga, to support attention mask, the alibi implementation is largely borrowed from
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.utils import logging
|
||||
|
||||
from .configuration_baichuan import BaichuanConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from transformers.models.bloom.modeling_bloom._make_causal_mask
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
|
||||
) -> torch.BoolTensor:
|
||||
"""
|
||||
Make causal mask used for self-attention.
|
||||
"""
|
||||
batch_size, target_length = input_ids_shape
|
||||
mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
|
||||
# ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
|
||||
seq_ids = torch.arange(target_length, device=device)
|
||||
mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask[:, :past_key_values_length] = False
|
||||
|
||||
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
|
||||
return expanded_mask
|
||||
|
||||
|
||||
# Copied from transformers.models.bloom.modeling_bloom._expand_mask
|
||||
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
|
||||
"""
|
||||
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
|
||||
"""
|
||||
batch_size, src_length = mask.shape
|
||||
tgt_length = tgt_length if tgt_length is not None else src_length
|
||||
|
||||
expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
|
||||
return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
|
||||
|
||||
|
||||
# Copied from transformers.models.bloom.modeling_bloom.build_alibi_tensor
|
||||
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""
|
||||
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
|
||||
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
|
||||
`softmax(l+a) = softmax(l)`.
|
||||
|
||||
Args:
|
||||
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
|
||||
attention_mask (`torch.Tensor`):
|
||||
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
|
||||
num_heads (`int`, *required*):
|
||||
number of heads
|
||||
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
|
||||
dtype of the output tensor
|
||||
"""
|
||||
batch_size, seq_length = attention_mask.shape
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
||||
base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
|
||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
|
||||
# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
|
||||
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
|
||||
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
|
||||
# => the query_length dimension will then be broadcasted correctly
|
||||
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
|
||||
alibi = slopes[..., None] * arange_tensor
|
||||
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, epsilon=1e-6):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.epsilon = epsilon
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
|
||||
|
||||
return (self.weight * hidden_states).to(input_dtype)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
||||
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.act_fn = ACT2FN[hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class BaichuanAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: "BaichuanConfig"):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.max_position_embeddings = config.model_max_length
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
|
||||
)
|
||||
|
||||
# Layer-wise attention scaling
|
||||
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||
self.beta = 1.0
|
||||
|
||||
self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
proj = self.W_pack(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||
proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
|
||||
query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
|
||||
query_states = query_states.transpose(1, 2).reshape(bsz * self.num_heads, q_len, self.head_dim)
|
||||
key_states = key_states.permute(0, 2, 3, 1).reshape(bsz * self.num_heads, self.head_dim, q_len)
|
||||
value_states = value_states.transpose(1, 2).reshape(bsz * self.num_heads, q_len, self.head_dim)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
past_key, past_value = past_key_value
|
||||
key_states = torch.cat([past_key, key_states], dim=2)
|
||||
value_states = torch.cat([past_value, value_states], dim=1)
|
||||
|
||||
_, _, kv_seq_len = key_states.shape
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# [batch_size * num_heads, q_length, kv_length]
|
||||
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
|
||||
matmul_result = alibi.baddbmm(
|
||||
batch1=query_states,
|
||||
batch2=key_states,
|
||||
beta=self.beta,
|
||||
alpha=self.inv_norm_factor,
|
||||
)
|
||||
|
||||
# change view to [batch_size, num_heads, q_length, kv_length]
|
||||
attention_scores = matmul_result.view(bsz, self.num_heads, q_len, kv_seq_len)
|
||||
|
||||
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
|
||||
# [batch_size, num_heads, q_length, kv_length]
|
||||
input_dtype = attention_scores.dtype
|
||||
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
|
||||
if input_dtype == torch.float16:
|
||||
attention_scores = attention_scores.to(torch.float)
|
||||
attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
|
||||
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
|
||||
|
||||
# change view [batch_size x num_heads, q_length, kv_length]
|
||||
attention_probs_reshaped = attention_probs.view(bsz * self.num_heads, q_len, kv_seq_len)
|
||||
|
||||
# matmul: [batch_size * num_heads, q_length, head_dim]
|
||||
attn_output = torch.bmm(attention_probs_reshaped, value_states)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attention_probs = None
|
||||
|
||||
return attn_output, attention_probs, past_key_value
|
||||
|
||||
|
||||
class BaichuanLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: "BaichuanConfig"):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = BaichuanAttention(config=config)
|
||||
self.mlp = MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
alibi=alibi,
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class BaichuanPreTrainedModel(PreTrainedModel):
|
||||
config_class = BaichuanConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BaichuanLayer"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, BaichuanModel):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_standard_cache(
|
||||
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
|
||||
num_heads, ...]))
|
||||
"""
|
||||
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
|
||||
num_heads = batch_size_times_num_heads // batch_size
|
||||
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
|
||||
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
|
||||
return tuple(
|
||||
(
|
||||
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
|
||||
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
|
||||
)
|
||||
for layer_past in past_key_value
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_baichuan_cache(
|
||||
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
|
||||
"""
|
||||
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
|
||||
batch_size_times_num_heads = batch_size * num_heads
|
||||
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
|
||||
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
|
||||
return tuple(
|
||||
(
|
||||
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
|
||||
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
|
||||
)
|
||||
for layer_past in past_key_value
|
||||
)
|
||||
|
||||
|
||||
class BaichuanModel(BaichuanPreTrainedModel):
|
||||
|
||||
def __init__(self, config: "BaichuanConfig"):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.n_head = config.num_attention_heads
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
|
||||
|
||||
self.gradient_checkpointing = config.gradient_checkpointing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
|
||||
return build_alibi_tensor(attention_mask, num_heads, dtype)
|
||||
|
||||
def _prepare_attn_mask(
|
||||
self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
|
||||
) -> torch.BoolTensor:
|
||||
# create causal mask
|
||||
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
|
||||
combined_attention_mask = None
|
||||
device = attention_mask.device
|
||||
_, src_length = input_shape
|
||||
|
||||
if src_length > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, device=device, past_key_values_length=past_key_values_length
|
||||
)
|
||||
|
||||
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
|
||||
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You need to provide input_ids or inputs_embeds")
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[1]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||
else:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# Compute alibi tensor: check build_alibi_tensor documentation
|
||||
alibi = self.build_alibi_tensor(attention_mask, self.n_head, dtype=hidden_states.dtype)
|
||||
|
||||
causal_mask = self._prepare_attn_mask(
|
||||
attention_mask,
|
||||
input_shape=(batch_size, seq_length),
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
alibi,
|
||||
causal_mask,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
alibi=alibi,
|
||||
attention_mask=causal_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = BaichuanModel(config)
|
||||
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
# the cache may be in the standard format (e.g. in contrastive search)
|
||||
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
|
||||
past_key_values = self._convert_to_baichuan_cache(past_key_values)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
def _reorder_cache(
|
||||
self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
|
||||
"""
|
||||
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
||||
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
||||
beam_idx at every generation step.
|
||||
|
||||
Output shares the same memory storage as `past`.
|
||||
"""
|
||||
standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
|
||||
|
||||
# Get a copy of `beam_idx` on all the devices where we need those indices.
|
||||
device_to_beam_idx = {
|
||||
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
|
||||
}
|
||||
reordered_past = tuple(
|
||||
(
|
||||
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
|
||||
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
|
||||
)
|
||||
for layer_past in standardized_past
|
||||
)
|
||||
return self._convert_to_baichuan_cache(reordered_past)
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Quantizes fine-tuned models with AutoGPTQ (https://github.com/PanQiWei/AutoGPTQ).
|
||||
# Quantizes models with AutoGPTQ (https://github.com/PanQiWei/AutoGPTQ).
|
||||
# Usage: python quantize.py --input_dir path_to_llama_model --output_dir path_to_quant_model --data_file alpaca.json
|
||||
# --max_length 1024 --max_samples 1024
|
||||
# dataset format: instruction (string), input (string), output (string), history (List[string])
|
||||
|
||||
Reference in New Issue
Block a user