38 Commits

Author SHA1 Message Date
hiyouga
0b8e19b6a6 fix webui
Former-commit-id: cf4cd52d36894f53a6ec45d003f887771012e5b4
2023-08-01 12:11:37 +08:00
hiyouga
8e26eb374e fix RM save model
Former-commit-id: 8104cc2425431eb1cddccf3909855296116f922b
2023-08-01 11:56:17 +08:00
hiyouga
9bba01a033 use git lfs
Former-commit-id: 4886d0071751f68c5a2d926bd9fcee0c93337322
2023-08-01 10:14:08 +08:00
hiyouga
661890b8a1 release v0.1.4
Former-commit-id: 81f84aaf2e120e39edb28ef42893939fc9a184e2
2023-08-01 10:08:47 +08:00
hiyouga
772ad4ec6b fix inference
Former-commit-id: 55dc2bdd3eaa552c655e584fc3cbbf017c7bc3e7
2023-08-01 00:06:48 +08:00
hiyouga
6f65f8cb3b fix arg check
Former-commit-id: 2c5c73de9ebc88e2d04e80754781c94a571133a0
2023-07-31 23:48:57 +08:00
hiyouga
43e83548b9 update readme
Former-commit-id: d99cda254e5025ff3f968d256197ab031bfabef1
2023-07-31 23:42:32 +08:00
hiyouga
dd3f3e9749 support streaming data, fix #284 #274 #268
Former-commit-id: 819cc1353599e5fa45658bc56dd0dbe4b258b197
2023-07-31 23:33:00 +08:00
hiyouga
124f61b404 Update data_args.py
Former-commit-id: 41ac5455af195747ba369c3a6dc7d412a366d54d
2023-07-28 17:42:41 +08:00
hiyouga
e8748cc6f3 update readme
Former-commit-id: 14d20cd1fdcfd1f2842362f70472b666e5d48c7d
2023-07-28 17:36:00 +08:00
hiyouga
fafec8b7a5 fix #268
Former-commit-id: 1eee0207fb370bb9e234e9bd3f9a0c47d7d01bc9
2023-07-28 17:02:26 +08:00
hiyouga
030daca686 update dataset
Former-commit-id: 4a044aabbd19c92a9ae93c1c30536f5086fd47f9
2023-07-26 17:05:12 +08:00
hiyouga
ac587438f8 fix #242
Former-commit-id: 80a346e29beb49e8935b786e2af1059fdc4954b2
2023-07-25 17:04:02 +08:00
hiyouga
c145bbef3c update dataset
Former-commit-id: 4fc2c3293d91d8464527ebd1ddabe572c8355616
2023-07-23 20:01:43 +08:00
hiyouga
745c46ee04 Update README_zh.md
Former-commit-id: 9d3c8803a34c06a2a5512fec3f841d7efcab3e3c
2023-07-22 14:31:16 +08:00
hiyouga
a707f5b502 update readme, fix web ui postprocess
Former-commit-id: ba51ab3379100108f7b52a3c2444ccdd99e8a6ef
2023-07-22 14:29:22 +08:00
hoshi-hiyouga
dc2e801077 Merge pull request #221 from mrhan1993/main
Add a Chinese README based on GLM Efficient Tuning; add a server_port argument to the web UI

Former-commit-id: 948f2abcb818211ee99d4c140e26044ca591369f
2023-07-22 13:04:25 +08:00
NULL
b56d5108b2 Merge branch 'hiyouga:main' into main
Former-commit-id: 8244c5b554ad0823e8ebea3d5583a6ecf9a66d2d
2023-07-21 17:00:26 +08:00
mrhan1993
8e6b7034fe Add a Chinese README based on GLM Efficient Tuning; add server_port to the web UI
Former-commit-id: 29e3acd23eafd891667d7a860ec544a5b05d3c33
2023-07-21 16:57:58 +08:00
hiyouga
dad7ca6633 release v0.1.3
Former-commit-id: 62c68bcbf591516e8f90b47810bea6f710fd23f6
2023-07-21 16:48:34 +08:00
hiyouga
a1468139a5 fix save function
Former-commit-id: 1d6beb0c8490a7531ffdf7a2819410597b200d12
2023-07-21 14:09:07 +08:00
hiyouga
49c90044ce Update runner.py
Former-commit-id: d7309deae46cfcdeeee79f54736df9b7e93b79ce
2023-07-21 13:35:19 +08:00
hiyouga
0f7cdac207 update web UI, support rm predict #210
Former-commit-id: 92cc6b655dc91b94d5bf9d8618c3b57d5cf94333
2023-07-21 13:27:27 +08:00
hiyouga
c4e9694c6e release v0.1.2
Former-commit-id: 04aad91b71cc3a1acaf1bcec4304ce6b2098f7dc
2023-07-20 22:33:59 +08:00
hiyouga
2006a96570 fix api
Former-commit-id: 4c3e8be325045e432b31c519132123c7b0689262
2023-07-20 22:14:54 +08:00
hoshi-hiyouga
5dcd95645f Merge pull request #213 from Ehco1996/patch-1
feat: support pass args before init web app
Former-commit-id: b0612c05bc10c281c0a95e08c5517c3fb0a72029
2023-07-20 22:12:07 +08:00
hiyouga
9b3304b054 update UI, fix #212
Former-commit-id: ac92c2bd7c47353759474fad9412f21b38c65501
2023-07-20 22:09:06 +08:00
Ehco
e580d4ef41 feat: support pass args before init web app
as title

Former-commit-id: 434a5077288927e0be15cd066ca3e562111fad4d
2023-07-20 21:49:26 +08:00
hiyouga
64db4abc68 Update README.md
Former-commit-id: 6dc67a495ec7d9fdc2574bae92063ed8a9099725
2023-07-20 17:23:16 +08:00
hiyouga
5ba0b80e5c simplify code
Former-commit-id: d3731754ab7c28ae81f60784e0e4213f279d93fe
2023-07-20 15:08:57 +08:00
hiyouga
7a43ff3d89 tiny fix
Former-commit-id: 22b1be7bbb9e7bd863acb88bf7365090b1b8235d
2023-07-19 22:53:46 +08:00
hiyouga
7e1a1d141a fix #199
Former-commit-id: 7fc778b49bc17688aca39fffe01f9d33e03e0c28
2023-07-19 22:51:29 +08:00
hiyouga
6d881f161b add datasets
Former-commit-id: 02e4b47dea1b25905c61f2ace88bab112610f021
2023-07-19 20:59:15 +08:00
hiyouga
a02b3e6192 fix #196
Former-commit-id: 85fd82926db345a590a7fb32c0e352a1d2f025c3
2023-07-19 17:35:38 +08:00
hiyouga
bcdee9fc19 fix #194
Former-commit-id: 9792921531efefb4bcddbde4380169a78fe064a6
2023-07-19 17:07:33 +08:00
hiyouga
8b688251be support LLaMA-2
Former-commit-id: 04dfda054855ee9256586aacbd382f8fb0bfed04
2023-07-19 16:42:14 +08:00
hiyouga
718f3382ad add LLaMA2 template
Former-commit-id: 246421bd35cf7bb2203ac4fc924e6cd1c292954d
2023-07-19 00:44:49 +08:00
hiyouga
dc8283d3d7 fix API
Former-commit-id: 9b10c9a12e33ab897056ecc61d977d221c19141b
2023-07-19 00:01:14 +08:00
68 changed files with 1535 additions and 708 deletions

.gitattributes (vendored): 1 change

@@ -1,2 +1,3 @@
 # Auto detect text files and perform LF normalization
 * text=auto
+*.json filter=lfs diff=lfs merge=lfs -text

README.md: 141 changes

@@ -8,31 +8,40 @@
 👋 Join our [WeChat](assets/wechat.jpg).
+\[ English | [中文](README_zh.md) \]
 ## Changelog
+[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
+[23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.
+[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.
 [23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
-[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Please replace the Baichuan-13B model file with `tests/modeling_baichuan.py` and try `--model_name_or_path path_to_baichuan_model` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.
+[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model.
 [23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
-[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--prompt_template intern` argument when you are using the InternLM-chat model.
+[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model.
 [23/07/05] Now we support training the **Falcon-7B/40B** models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model.
-[23/06/29] We provide a **reproducible example** of training a chat model using instruction-following datasets, see this [HuggingFace Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.
+[23/06/29] We provide a **reproducible example** of training a chat model using instruction-following datasets, see this [Hugging Face Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.
 [23/06/22] Now we align the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
-[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model. If you want to train with RTX3090, use `git checkout baichuan-7b-rtx3090` to switch to the `baichuan-7b-rtx3090` branch and try the `--baichuan_rtx_gpu true` argument. (Other RTX series GPUs can also be tried)
+[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model.
-[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
+[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
 [23/05/31] Now we support training the **BLOOM & BLOOMZ** models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model.
 ## Supported Models
 - [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
+- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
 - [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
 - [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
 - [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
@@ -57,36 +66,41 @@
 ## Provided Datasets
 - For pre-training:
-- [Wiki Demo](data/wiki_demo.txt)
+- [Wiki Demo (en)](data/wiki_demo.txt)
+- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
+- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
+- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
+- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
 - For supervised fine-tuning:
-- [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
-- [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
-- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
-- [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
-- [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
-- [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
-- [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
-- [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
-- [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
-- [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
-- [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
-- [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
-- [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
-- [Web QA (Chinese)](https://huggingface.co/datasets/suolyer/webqa)
-- [UltraChat](https://github.com/thunlp/UltraChat)
-- [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
-- [Open Assistant (Chinese)](https://huggingface.co/datasets/OpenAssistant/oasst1)
-- [WebNovel (Chinese)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- For reward model training:
-- [HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf)
-- [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
-- [Open Assistant (Chinese)](https://huggingface.co/datasets/OpenAssistant/oasst1)
-- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
-- [GPT-4 Generated Data (Chinese)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
+- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
+- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
+- [Self-cognition (zh)](data/self_cognition.json)
+- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
+- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
+- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
+- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
+- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
+- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
+- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
+- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
+- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
+- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
+- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
+- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
+- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
+- [UltraChat (en)](https://github.com/thunlp/UltraChat)
+- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
+- For reward modelling:
+- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
+- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
+- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
 Please refer to [data/README.md](data/README.md) for details.
-Some datasets require confirmation before using them, so we recommend logging in with your HuggingFace account using these commands.
+Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
 ```bash
 pip install --upgrade huggingface_hub
@@ -103,12 +117,6 @@ huggingface-cli login
 And **powerful GPUs**!
-If you want to enable quantized LoRA (QLoRA) on the Windows platform, you should install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
-```bash
-pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
-```
 ## Getting Started
 ### Data Preparation (optional)
@@ -120,6 +128,7 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
 ### Dependence Installation (optional)
 ```bash
+git lfs install
 git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
 conda create -n llama_etuning python=3.10
 conda activate llama_etuning
@@ -127,12 +136,20 @@ cd LLaMA-Efficient-Tuning
 pip install -r requirements.txt
 ```
+If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
+```bash
+pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
+```
 ### All-in-one Web UI
 ```bash
 python src/train_web.py
 ```
+Currently the web UI only supports training on a single GPU.
 ### (Continually) Pre-Training
 ```bash
@@ -141,6 +158,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 --model_name_or_path path_to_your_model \
 --do_train \
 --dataset wiki_demo \
+--template default \
 --finetuning_type lora \
 --output_dir path_to_pt_checkpoint \
 --overwrite_cache \
@@ -163,6 +181,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 --model_name_or_path path_to_your_model \
 --do_train \
 --dataset alpaca_gpt4_en \
+--template default \
 --finetuning_type lora \
 --output_dir path_to_sft_checkpoint \
 --overwrite_cache \
@@ -185,7 +204,10 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 --model_name_or_path path_to_your_model \
 --do_train \
 --dataset comparison_gpt4_en \
+--template default \
 --finetuning_type lora \
+--resume_lora_training False \
+--checkpoint_dir path_to_sft_checkpoint \
 --output_dir path_to_rm_checkpoint \
 --per_device_train_batch_size 4 \
 --gradient_accumulation_steps 4 \
@@ -206,7 +228,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 --model_name_or_path path_to_your_model \
 --do_train \
 --dataset alpaca_gpt4_en \
+--template default \
 --finetuning_type lora \
+--resume_lora_training False \
 --checkpoint_dir path_to_sft_checkpoint \
 --reward_model path_to_rm_checkpoint \
 --output_dir path_to_ppo_checkpoint \
@@ -217,7 +241,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 --save_steps 1000 \
 --learning_rate 1e-5 \
 --num_train_epochs 1.0 \
---resume_lora_training False \
 --plot_loss
 ```
@@ -260,34 +283,57 @@ use_cpu: false
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
---stage pt \
+--stage sft \
 --model_name_or_path path_to_your_model \
 --do_eval \
 --dataset alpaca_gpt4_en \
+--template default \
+--finetuning_type lora \
 --checkpoint_dir path_to_checkpoint \
 --output_dir path_to_eval_result \
 --per_device_eval_batch_size 8 \
---max_samples 50 \
+--max_samples 100 \
 --predict_with_generate
 ```
 We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
+### Predict
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+--stage sft \
+--model_name_or_path path_to_your_model \
+--do_predict \
+--dataset alpaca_gpt4_en \
+--template default \
+--finetuning_type lora \
+--checkpoint_dir path_to_checkpoint \
+--output_dir path_to_predict_result \
+--per_device_eval_batch_size 8 \
+--max_samples 100 \
+--predict_with_generate
+```
 ### API Demo
 ```bash
 python src/api_demo.py \
 --model_name_or_path path_to_your_model \
+--template default \
+--finetuning_type lora \
 --checkpoint_dir path_to_checkpoint
 ```
-See `http://localhost:8000/docs` for API documentation.
+Visit `http://localhost:8000/docs` for API documentation.
 ### CLI Demo
 ```bash
 python src/cli_demo.py \
 --model_name_or_path path_to_your_model \
+--template default \
+--finetuning_type lora \
 --checkpoint_dir path_to_checkpoint
 ```
@@ -296,6 +342,8 @@ python src/cli_demo.py \
 ```bash
 python src/web_demo.py \
 --model_name_or_path path_to_your_model \
+--template default \
+--finetuning_type lora \
 --checkpoint_dir path_to_checkpoint
 ```
@@ -304,10 +352,18 @@ python src/web_demo.py \
 ```bash
 python src/export_model.py \
 --model_name_or_path path_to_your_model \
+--template default \
+--finetuning_type lora \
 --checkpoint_dir path_to_checkpoint \
 --output_dir path_to_export
 ```
+## TODO
+- [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
+- [ ] Implementing multi-query attention for faster inference.
+- [ ] Supporting full-parameter RLHF training.
 ## License
 This repository is licensed under the [Apache-2.0 License](LICENSE).
@@ -315,9 +371,10 @@ This repository is licensed under the [Apache-2.0 License](LICENSE).
 Please follow the model licenses to use the corresponding model weights:
 - [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
+- [LLaMA-2](https://ai.meta.com/llama/license/)
 - [BLOOM](https://huggingface.co/spaces/bigscience/license)
 - [Falcon](LICENSE)
-- [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
+- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
 - [InternLM](https://github.com/InternLM/InternLM#open-source-license)
 ## Citation

README_zh.md (new file): 399 lines

@@ -0,0 +1,399 @@
# LLaMA Efficient Tuning
[![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/LLaMA-Efficient-Tuning?style=social)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/stargazers)
[![GitHub Code License](https://img.shields.io/github/license/hiyouga/LLaMA-Efficient-Tuning)](LICENSE)
[![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Efficient-Tuning)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/commits/main)
[![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
[![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Efficient-Tuning/pulls)
👋 加入我们的[微信群](assets/wechat.jpg)。
\[ [English](README.md) | 中文 \]
## 更新日志
[23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming` 和 `--max_steps 100` 参数来流式加载数据集。
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft))。
[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--template llama2` 参数。
[23/07/18] 我们开发了支持训练和测试的浏览器一键微调界面。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path path_to_baichuan_model` 和 `--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--template baichuan` 参数。
[23/07/09] 我们开源了 [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。请注意使用 InternLM-chat 模型需要添加 `--template intern` 参数。
[23/07/05] 现在我们支持了 **Falcon-7B/40B** 模型的训练。请尝试使用 `--model_name_or_path tiiuae/falcon-7b` 和 `--lora_target query_key_value` 参数。
[23/06/29] 我们提供了一个**可复现的**指令模型微调示例,详细内容请查阅 [Hugging Face 项目](https://huggingface.co/hiyouga/baichuan-7b-sft)。
[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入任意基于 ChatGPT 的应用中。
[23/06/15] 现在我们支持了 **Baichuan-7B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-7B` 和 `--lora_target W_pack` 参数。
[23/06/03] 现在我们实现了 4 比特的 LoRA 训练(也称 [QLoRA](https://github.com/artidoro/qlora))。请尝试使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
[23/05/31] 现在我们支持了 **BLOOM & BLOOMZ** 模型的训练。请尝试使用 `--model_name_or_path bigscience/bloomz-7b1-mt` 和 `--lora_target query_key_value` 参数。
## 模型
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
- [InternLM](https://github.com/InternLM/InternLM) (7B)
## 微调方法
- [二次预训练](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
- 全参数微调
- 部分参数微调
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [指令监督微调](https://arxiv.org/abs/2109.01652)
- 全参数微调
- 部分参数微调
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [人类反馈的强化学习(RLHF)](https://arxiv.org/abs/2203.02155)
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
## 数据集
- 用于二次预训练:
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- 用于指令监督微调:
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- 用于奖励模型训练:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
使用方法请参考 [data/README.md](data/README_zh.md) 文件。
部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。
```bash
pip install --upgrade huggingface_hub
huggingface-cli login
```
## 软件依赖
- Python 3.8+ 和 PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
- jieba, rouge-chinese 和 nltk (用于评估)
- gradio 和 matplotlib (用于网页端交互)
- uvicorn, fastapi 和 sse-starlette (用于 API)
以及 **强而有力的 GPU**
## 如何使用
### 数据准备(可跳过)
关于数据集文件的格式,请参考 `data/example_dataset` 文件夹的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
注意:使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README.md`
### 环境搭建(可跳过)
```bash
git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning
cd LLaMA-Efficient-Tuning
pip install -r requirements.txt
```
如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库,支持 CUDA 11.1 到 12.1。
```bash
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
### 浏览器一键微调/测试
```bash
python src/train_web.py
```
目前网页 UI 仅支持单卡训练。
### 二次预训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage pt \
--model_name_or_path path_to_your_model \
--do_train \
--dataset wiki_demo \
--template default \
--finetuning_type lora \
--output_dir path_to_pt_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
### 指令监督微调
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--plot_loss \
--fp16
```
### 奖励模型训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage rm \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### RLHF 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage ppo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss
```
### 多 GPU 分布式训练
```bash
accelerate config # 首先配置分布式环境
accelerate launch src/train_bash.py # 参数同上
```
<details><summary>使用 DeepSpeed ZeRO-2 进行全参数微调的 Accelerate 配置示例</summary>
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
gradient_accumulation_steps: 4
gradient_clipping: 0.5
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
</details>
### 指标评估(BLEU 分数和汉语 ROUGE 分数)
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--do_eval \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_eval_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
```
我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1``--max_target_length 128` 参数。
### 模型预测
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_your_model \
--do_predict \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
```
### API 服务
```bash
python src/api_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
关于 API 文档请见 `http://localhost:8000/docs`
### 命令行测试
```bash
python src/cli_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
### 浏览器测试
```bash
python src/web_demo.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint
```
### 导出微调模型
```bash
python src/export_model.py \
--model_name_or_path path_to_your_model \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_export
```
## TODO
- [ ] 实现 flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention))。
- [ ] 在推理阶段使用 Multi-query attention 进行加速。
- [ ] 支持 RLHF 的全参数微调。
## 协议
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
使用模型权重时,请遵循对应的模型协议:
- [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
- [LLaMA-2](https://ai.meta.com/llama/license/)
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
- [Falcon](LICENSE)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
## 引用
如果您觉得此项目有帮助,请考虑以下列格式引用
```bibtex
@Misc{llama-efficient-tuning,
title = {LLaMA Efficient Tuning},
author = {hiyouga},
howpublished = {\url{https://github.com/hiyouga/LLaMA-Efficient-Tuning}},
year = {2023}
}
```
## 致谢
本项目是 [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning) 的同类项目。采用了类似的代码结构和训练方法。
## Star History
![Star History Chart](https://api.star-history.com/svg?repos=hiyouga/LLaMA-Efficient-Tuning&type=Date)

data/README.md

@@ -1,53 +1,18 @@
-Data format in `dataset_info.json`:
+If you are using a custom dataset, please provide your dataset definition in the following format in `dataset_info.json`.
 ```json
 "dataset_name": {
   "hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
   "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
   "file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
   "file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
   "columns": {
     "prompt": "the name of the column in the datasets containing the prompts. (default: instruction)",
     "query": "the name of the column in the datasets containing the queries. (default: input)",
     "response": "the name of the column in the datasets containing the responses. (default: output)",
     "history": "the name of the column in the datasets containing the history of chat. (default: None)"
   }
 }
 ```
-`dataset_info.json` 中的数据集定义格式:
+where the `prompt` and `response` columns should contain non-empty values. The `query` column will be concatenated with the `prompt` column and used as input for the model. The `history` column should contain a list where each element is a string tuple representing a query-response pair.
-```json
-"数据集名称": {
-  "hf_hub_url": "HuggingFace上的项目地址若指定则忽略下列三个参数",
-  "script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
-  "file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
-  "file_sha1": "数据集文件的SHA-1哈希值可选",
-  "columns": {
-    "prompt": "数据集代表提示词的表头名称默认instruction",
-    "query": "数据集代表请求的表头名称默认input",
-    "response": "数据集代表回答的表头名称默认output",
-    "history": "数据集代表历史对话的表头名称默认None"
-  }
-}
-```
-部分预置数据集简介:
-| 数据集名称 | 规模 | 描述 |
-| --- | --- | --- |
-| [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) | 52k | 斯坦福大学开源的 Alpaca 数据集,训练了 Alpaca 这类早期基于 LLaMA 的模型 |
-| [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) | 51k | 使用 ChatGPT 翻译的 Alpaca 数据集 |
-| [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) | 100k+ | 基于 GPT-4 的 self-instruction 数据集 |
-| [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN) | 2m | 包含约 200 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文指令数据 |
-| [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN) | 1m | 包含约 100 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文指令数据 |
-| [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) | 500k | 包含约 50 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文指令数据 |
-| [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) | 400k | 包含约 40 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的个性化角色对话数据,包含角色介绍 |
-| [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) | 250k | 包含约 25 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文数学题数据,包含解题过程 |
-| [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) | 800k | 包含约 80 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的用户与助手的多轮对话 |
-| [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) | 100k+ | 包含日文、简繁体中文、英文等多类数据,数据集原用于 Guanaco 模型训练 |
-| [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) | 1.1M | 中文对话大模型 firefly流萤的中文数据集包含多个 NLP 任务 |
-| [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) | 20k | 英文代码生成任务数据集 |
-| [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) | 6M | 用于微调的指令数据集集合 |
-| [Web QA](https://huggingface.co/datasets/suolyer/webqa) | 36k | 百度知道汇集的中文问答数据集 |
-| [UltraChat](https://github.com/thunlp/UltraChat) | 1.57M | 清华 NLP 发布的大规模多轮对话数据集 |
-BELLE 数据集是由 ChatGPT 产生的数据集,不保证数据准确性,所有类 GPT 模型产生的 self-instruction 数据集均不能保证其准确性。
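For reference, a custom entry following the format above might look like the sketch below. The dataset name, file name, SHA-1 value, and column names are placeholders for illustration, not files shipped with the repository.

```json
"my_custom_dataset": {
  "file_name": "my_custom_dataset.json",
  "file_sha1": "0000000000000000000000000000000000000000",
  "columns": {
    "prompt": "instruction",
    "query": "input",
    "response": "output",
    "history": "history"
  }
}
```

Once such an entry is added to `data/dataset_info.json`, the dataset can be selected with `--dataset my_custom_dataset` in the training commands shown in the README.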

data/README_zh.md (new file): 18 lines

@@ -0,0 +1,18 @@
如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中以如下格式提供您的数据集定义。
```json
"数据集名称": {
"hf_hub_url": "HuggingFace上的项目地址若指定则忽略下列三个参数",
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
"file_sha1": "数据集文件的SHA-1哈希值可选",
"columns": {
"prompt": "数据集代表提示词的表头名称默认instruction",
"query": "数据集代表请求的表头名称默认input",
"response": "数据集代表回答的表头名称默认output",
"history": "数据集代表历史对话的表头名称默认None"
}
}
```
其中 `prompt``response` 列应当是非空的字符串。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。`history` 列应当是一个列表,其中每个元素是一个字符串二元组,分别代表用户请求和模型答复。


@@ -1 +0,0 @@
3779ddbc040543ab1834ef216c983d6fcc06cc9a


@@ -1 +0,0 @@
fc9a6a3458caca2af8dafc6181773fe10c6d8657


@@ -1 +0,0 @@
25508714b7879a1e5a6764ba7f979a980f549f1a


@@ -1 +0,0 @@
7cb6a7d11455bddc3d495750a2392683d775b184


@@ -1 +0,0 @@
f5cb08305ff5dc9c17a09809c54c8c8834aadc70


@@ -1 +0,0 @@
aee47b7b443496e37808d7f34ef10403ff99bcc3


@@ -1 +0,0 @@
274079ea921762be356de85b18f13fa60b7ba8cb


@@ -1 +0,0 @@
0a57fbc1d8cb08a8cd71c5eb8425cf59206ffed6


@@ -1,2 +0,0 @@
{"id": 0,"title": "大卫·亨利","content": "大卫·亨利\n\n大卫·克莱顿·亨利David Clayton Henrie美国演员。近来在迪士尼频道原创电视影集《少年魔法师》Wizards of Waverly Place当中演出贾斯汀·鲁索Justin Russo一角。\n\n大卫·亨利出生在加州Mission Viejo在凤凰城长大。他的胞弟劳伦斯·亨利Lorenzo Henrie也是演员。大卫·亨利就读夏安传统学校。家中是信奉罗马天主教。 \n\n大卫在2007年拍摄少年魔法师期间认识女演员露西·海尔Lucy Hale之后与其交往于2009年分手。\n\n10岁时大卫·亨利和SAG在凤凰城签订了合约并开始走出去试镜。 9岁的时候在沙加缅度进行商业拍摄SAG董事建议大卫·亨利搬到洛杉矶。在10岁那年夏天他和他的家人搬到了好莱坞。他预定他的前2支商业试镜扮演主要角色为汉堡王和桂格燕麦。他初演电视节目为Providence。 \n\n到了13岁大卫有了他的第一次重大突破在福克斯公司的喜剧The Pitts饰演 Petey Pitt一角。大卫下出作品为的Hallmark movie为Monster Maker和琳达布莱儿、乔治甘迺迪共同演出并要求回来Hallmark movie公司。 \n\n在18岁时大卫得到了迪士尼频道原创系列演出机会该节目2007年10月12日首播。大卫2008年参加了迪士尼频道的游戏节目。他是绿色团队的队长隔年为旋风队队长。他在迪士尼原创电影《少年魔法师》之后在《酷爸的疯狂假期》中有饰演一角。\n"}
{"id": 1,"title": "大卫·亨利","content": "大卫·亨利\n\n大卫·克莱顿·亨利David Clayton Henrie美国演员。近来在迪士尼频道原创电视影集《少年魔法师》Wizards of Waverly Place当中演出贾斯汀·鲁索Justin Russo一角。\n\n大卫·亨利出生在加州Mission Viejo在凤凰城长大。他的胞弟劳伦斯·亨利Lorenzo Henrie也是演员。大卫·亨利就读夏安传统学校。家中是信奉罗马天主教。 \n\n大卫在2007年拍摄少年魔法师期间认识女演员露西·海尔Lucy Hale之后与其交往于2009年分手。\n\n10岁时大卫·亨利和SAG在凤凰城签订了合约并开始走出去试镜。 9岁的时候在沙加缅度进行商业拍摄SAG董事建议大卫·亨利搬到洛杉矶。在10岁那年夏天他和他的家人搬到了好莱坞。他预定他的前2支商业试镜扮演主要角色为汉堡王和桂格燕麦。他初演电视节目为Providence。 \n\n到了13岁大卫有了他的第一次重大突破在福克斯公司的喜剧The Pitts饰演 Petey Pitt一角。大卫下出作品为的Hallmark movie为Monster Maker和琳达布莱儿、乔治甘迺迪共同演出并要求回来Hallmark movie公司。 \n\n在18岁时大卫得到了迪士尼频道原创系列演出机会该节目2007年10月12日首播。大卫2008年参加了迪士尼频道的游戏节目。他是绿色团队的队长隔年为旋风队队长。他在迪士尼原创电影《少年魔法师》之后在《酷爸的疯狂假期》中有饰演一角。\n"}

requirements.txt

@@ -1,8 +1,8 @@
 torch>=1.13.1
 transformers>=4.29.1
 datasets>=2.12.0
-accelerate>=0.19.0
-peft>=0.3.0
+accelerate>=0.21.0
+peft>=0.4.0
 trl>=0.4.7
 sentencepiece
 jieba
@@ -10,7 +10,7 @@ rouge-chinese
 nltk
 gradio>=3.36.0
 uvicorn
-pydantic
-fastapi
+pydantic==1.10.11
+fastapi==0.95.1
 sse-starlette
 matplotlib

src/api_demo.py

@@ -5,9 +5,16 @@
 import uvicorn
-from llmtuner import create_app
+from llmtuner import ChatModel
+from llmtuner.api.app import create_app
+from llmtuner.tuner import get_infer_args
+def main():
+    chat_model = ChatModel(*get_infer_args())
+    app = create_app(chat_model)
+    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
 if __name__ == "__main__":
-    app = create_app()
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    main()
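Since the README aligns this demo API with the OpenAI chat completions format, a request against the running server might look like the minimal sketch below. The `/v1/chat/completions` route and the payload fields are assumptions based on that format rather than something shown in this diff, so verify them against `src/llmtuner/api/app.py` before relying on them.

```python
# Minimal client sketch for the server started by `python src/api_demo.py`.
# Assumes an OpenAI-compatible endpoint on localhost:8000 (see the note above).
import json
import urllib.request

payload = {
    "model": "default",  # placeholder model name
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 128,
}
request = urllib.request.Request(
    "http://localhost:8000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    result = json.loads(response.read().decode("utf-8"))
    print(result["choices"][0]["message"]["content"])
```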

src/cli_demo.py

@@ -2,7 +2,8 @@
 # Implements stream chat in command line for fine-tuned models.
 # Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
-from llmtuner import ChatModel, get_infer_args
+from llmtuner import ChatModel
+from llmtuner.tuner import get_infer_args
 def main():

src/export_model.py

@@ -2,7 +2,7 @@
 # Exports the fine-tuned model.
 # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
-from llmtuner import get_train_args, load_model_and_tokenizer
+from llmtuner.tuner import get_train_args, load_model_and_tokenizer
 def main():

src/llmtuner/__init__.py

@@ -1,7 +1,4 @@
-from llmtuner.api import create_app
 from llmtuner.chat import ChatModel
-from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo
-from llmtuner.webui import create_ui
-__version__ = "0.1.0"
+__version__ = "0.1.4"


@@ -1 +0,0 @@
from llmtuner.api.app import create_app

src/llmtuner/api/app.py

@@ -1,4 +1,3 @@
-import json
 import uvicorn
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
@@ -31,9 +30,7 @@ async def lifespan(app: FastAPI): # collects GPU memory
     torch_gc()
-def create_app():
-    chat_model = ChatModel(*get_infer_args())
+def create_app(chat_model: ChatModel) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
     app.add_middleware(
@@ -96,7 +93,7 @@ def create_app():
                 finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-            yield json.dumps(chunk, ensure_ascii=False)
+            yield chunk.json(exclude_unset=True, ensure_ascii=False)
             for new_text in chat_model.stream_chat(
                 query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
@@ -110,7 +107,7 @@ def create_app():
                 finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-            yield json.dumps(chunk, ensure_ascii=False)
+            yield chunk.json(exclude_unset=True, ensure_ascii=False)
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
@@ -118,12 +115,13 @@ def create_app():
             finish_reason=Finish.STOP
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
-        yield json.dumps(chunk, ensure_ascii=False)
+        yield chunk.json(exclude_unset=True, ensure_ascii=False)
         yield "[DONE]"
     return app
 if __name__ == "__main__":
-    app = create_app()
+    chat_model = ChatModel(*get_infer_args())
+    app = create_app(chat_model)
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)


@@ -1,37 +1,54 @@
 import torch
-from typing import Any, Dict, Generator, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
 from llmtuner.extras.misc import get_logits_processor
-from llmtuner.extras.template import Template
+from llmtuner.extras.template import get_template
-from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 from llmtuner.tuner import load_model_and_tokenizer
+if TYPE_CHECKING:
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 class ChatModel:
     def __init__(
         self,
-        model_args: ModelArguments,
-        data_args: DataArguments,
-        finetuning_args: FinetuningArguments,
-        generating_args: GeneratingArguments
+        model_args: "ModelArguments",
+        data_args: "DataArguments",
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments"
     ) -> None:
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-        self.template = Template(data_args.prompt_template)
-        self.source_prefix = data_args.source_prefix if data_args.source_prefix else ""
+        if torch.cuda.device_count() > 1:
+            from accelerate import dispatch_model
+            from accelerate.utils import infer_auto_device_map, get_balanced_memory
+            device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model))
+            self.model = dispatch_model(self.model, device_map)
+        else:
+            self.model = self.model.cuda()
+        self.template = get_template(data_args.template)
+        self.source_prefix = data_args.source_prefix
         self.generating_args = generating_args
     def process_args(
-        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = None,
+        **input_kwargs
     ) -> Tuple[Dict[str, Any], int]:
-        prefix = prefix if prefix else self.source_prefix
+        prefix = prefix or self.source_prefix
-        inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt")
+        prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token)
+        inputs = self.tokenizer([prompt], return_tensors="pt")
         inputs = inputs.to(self.model.device)
         prompt_length = len(inputs["input_ids"][0])
+        do_sample = input_kwargs.pop("do_sample", None)
         temperature = input_kwargs.pop("temperature", None)
         top_p = input_kwargs.pop("top_p", None)
         top_k = input_kwargs.pop("top_k", None)
@@ -42,6 +59,7 @@ class ChatModel:
         gen_kwargs = self.generating_args.to_dict()
         gen_kwargs.update(dict(
             input_ids=inputs["input_ids"],
+            do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
             temperature=temperature or gen_kwargs["temperature"],
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
@@ -61,7 +79,11 @@ class ChatModel:
     @torch.inference_mode()
     def chat(
-        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = None,
+        **input_kwargs
     ) -> Tuple[str, Tuple[int, int]]:
         gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
         generation_output = self.model.generate(**gen_kwargs)
@@ -72,7 +94,11 @@ class ChatModel:
     @torch.inference_mode()
     def stream_chat(
-        self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
+        self,
+        query: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        prefix: Optional[str] = None,
+        **input_kwargs
     ) -> Generator[str, None, None]:
         gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
@@ -81,5 +107,4 @@ class ChatModel:
         thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
         thread.start()
-        for new_text in streamer:
-            yield new_text
+        yield from streamer
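The diff above gives `ChatModel` its current `chat` and `stream_chat` entry points, which the CLI, web, and API demos all build on. A minimal usage sketch, assuming the inference arguments are parsed by `get_infer_args` exactly as in `src/api_demo.py` and `src/cli_demo.py`:

```python
# Minimal sketch: drive ChatModel directly with the usual inference flags, e.g.
#   python sketch.py --model_name_or_path path_to_your_model --template default \
#       --finetuning_type lora --checkpoint_dir path_to_checkpoint
from llmtuner import ChatModel
from llmtuner.tuner import get_infer_args


def main():
    chat_model = ChatModel(*get_infer_args())
    query = "Hello, who are you?"
    history = []

    # Blocking call: returns the full response plus (prompt_length, response_length).
    response, (prompt_length, response_length) = chat_model.chat(query, history)
    print(response)

    # Streaming call: yields text chunks as they are generated.
    for new_text in chat_model.stream_chat(query, history):
        print(new_text, end="", flush=True)


if __name__ == "__main__":
    main()
```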

src/llmtuner/dsets/__init__.py

@@ -1,2 +1,3 @@
 from llmtuner.dsets.loader import get_dataset
 from llmtuner.dsets.preprocess import preprocess_dataset
+from llmtuner.dsets.utils import split_dataset


@@ -1,63 +0,0 @@
import os
import json
import time
from datetime import timedelta
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments
)
class LogCallback(TrainerCallback):
def __init__(self, runner=None):
self.runner = runner
self.start_time = time.time()
self.tracker = {}
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
r"""
Event called at the beginning of a training step. If using gradient accumulation, one training step
might take several inputs.
"""
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
if self.runner is not None and self.runner.aborted:
control.should_epoch_stop = True
control.should_training_stop = True
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
r"""
Event called after logging the last logs.
"""
if "loss" not in state.log_history[-1]:
return
cur_time = time.time()
cur_steps = state.log_history[-1].get("step")
elapsed_time = cur_time - self.start_time
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
remaining_steps = state.max_steps - cur_steps
remaining_time = remaining_steps * avg_time_per_step
self.tracker = {
"current_steps": cur_steps,
"total_steps": state.max_steps,
"loss": state.log_history[-1].get("loss", None),
"reward": state.log_history[-1].get("reward", None),
"learning_rate": state.log_history[-1].get("learning_rate", None),
"epoch": state.log_history[-1].get("epoch", None),
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
"remaining_time": str(timedelta(seconds=int(remaining_time)))
}
os.makedirs(args.output_dir, exist_ok=True)
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
f.write(json.dumps(self.tracker) + "\n")
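The callback above appends one JSON object per logging step to trainer_log.jsonl, which the web UI can poll for progress. A rough sketch of reading it back; the output path is an assumption for the example:

import json

with open("output/trainer_log.jsonl", encoding="utf-8") as f:  # assumed path
    for line in f:
        record = json.loads(line)
        print(record["current_steps"], record["total_steps"], record["loss"], record["remaining_time"])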

View File

@@ -1,40 +1,50 @@
import os
import hashlib
- from typing import List
+ from typing import TYPE_CHECKING, List, Optional
- from datasets import Dataset, concatenate_datasets, load_dataset
+ from datasets import concatenate_datasets, interleave_datasets, load_dataset
from llmtuner.extras.logging import get_logger
- from llmtuner.hparams import ModelArguments, DataArguments
+ if TYPE_CHECKING:
+ from datasets import Dataset
+ from llmtuner.hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
EXT2TYPE = {
"csv": "csv",
"json": "json",
"jsonl": "json",
"txt": "text"
}
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
return
if len(data_files) != 1:
logger.warning("Checksum failed: too many files.")
return
with open(data_files[0], "rb") as f:
sha1 = hashlib.sha1(f.read()).hexdigest()
if sha1 != file_sha1:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def get_dataset(
- model_args: ModelArguments,
+ model_args: "ModelArguments",
- data_args: DataArguments
+ data_args: "DataArguments"
- ) -> Dataset:
+ ) -> "Dataset":
def checksum(file_path, hash):
with open(file_path, "rb") as datafile:
binary_data = datafile.read()
sha1 = hashlib.sha1(binary_data).hexdigest()
if sha1 != hash:
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
ext2type = {
"csv": "csv",
"json": "json",
"jsonl": "json",
"txt": "text"
}
max_samples = data_args.max_samples
- all_datasets: List[Dataset] = [] # support multiple datasets
+ all_datasets: List["Dataset"] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
if dataset_attr.load_from == "hf_hub":
@@ -47,60 +57,56 @@ def get_dataset(
data_path = None
data_files: List[str] = []
- if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
+ if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
if data_path is None:
- data_path = ext2type.get(data_files[0].split(".")[-1], None)
+ data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
else:
- assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
+ assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match."
- elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
+ elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
- data_path = ext2type.get(data_files[0].split(".")[-1], None)
+ data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
else:
raise ValueError("File not found.")
assert data_path, "File extension must be txt, csv, json or jsonl."
checksum(data_files, dataset_attr.dataset_sha1)
if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
checksum(data_files[0], dataset_attr.dataset_sha1)
else:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
else:
raise NotImplementedError
- raw_datasets = load_dataset(
+ dataset = load_dataset(
data_path,
data_files=data_files,
+ split=data_args.split,
cache_dir=model_args.cache_dir,
+ streaming=data_args.streaming,
use_auth_token=True if model_args.use_auth_token else None
)
- dataset = raw_datasets[data_args.split]
if max_samples is not None:
max_samples_temp = min(len(dataset), max_samples)
dataset = dataset.select(range(max_samples_temp))
- dummy_data = [None] * len(dataset)
- prefix_data = [dataset_attr.source_prefix] * len(dataset)
- for column_name, target_name in [
- ("prompt_column", "prompt"),
- ("query_column", "query"),
- ("response_column", "response"),
- ("history_column", "history")
- ]: # every dataset will have 4 columns same as each other
- if getattr(dataset_attr, column_name) != target_name:
- if getattr(dataset_attr, column_name):
- dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
- else: # None or empty string
- dataset = dataset.add_column(target_name, dummy_data)
- dataset = dataset.add_column("prefix", prefix_data)
+ for column_name in ["prompt", "query", "response", "history"]: # align datasets
+ if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
+ dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
+ if dataset_attr.source_prefix: # add prefix
+ dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix})
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
- all_datasets = all_datasets[0]
+ return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
return interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)
else:
- all_datasets = concatenate_datasets(all_datasets)
- return all_datasets
+ raise ValueError("Unknown mixing strategy.")

View File

@@ -1,65 +1,63 @@
- from typing import Literal
+ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
from itertools import chain
- from transformers import Seq2SeqTrainingArguments
- from transformers.tokenization_utils import PreTrainedTokenizer
- from datasets import Dataset
from llmtuner.extras.constants import IGNORE_INDEX
- from llmtuner.extras.template import Template
+ from llmtuner.extras.template import get_template
- from llmtuner.hparams import DataArguments
+ if TYPE_CHECKING:
+ from datasets import Dataset
+ from transformers import Seq2SeqTrainingArguments
+ from transformers.tokenization_utils import PreTrainedTokenizer
+ from llmtuner.hparams import DataArguments
def preprocess_dataset(
- dataset: Dataset,
+ dataset: "Dataset",
- tokenizer: PreTrainedTokenizer,
+ tokenizer: "PreTrainedTokenizer",
- data_args: DataArguments,
+ data_args: "DataArguments",
- training_args: Seq2SeqTrainingArguments,
+ training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"]
- ) -> Dataset:
+ ) -> "Dataset":
- column_names = list(dataset.column_names)
- prompt_template = Template(data_args.prompt_template)
- # support question with a single answer or multiple answers
- def get_dialog(examples):
+ column_names = list(dataset.column_names or [])
+ template = get_template(data_args.template)
+ def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
for i in range(len(examples["prompt"])):
- if examples["prompt"][i] and examples["response"][i]:
- query, answer = examples["prompt"][i], examples["response"][i]
- query = query + "\n" + examples["query"][i] if examples["query"][i] else query
- prefix = examples["prefix"][i] if examples["prefix"][i] else ""
- dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
- yield dialog
+ query, response = examples["prompt"][i], examples["response"][i]
+ query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
+ history = history if "history" in examples and examples["history"][i] else []
+ prefix = prefix if "prefix" in examples and examples["prefix"][i] else ""
+ yield query, response, history, prefix
- def preprocess_pretrain_dataset(examples):
+ def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
- text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
- concatenated_ids = list(chain(*text_ids))
- total_length = len(concatenated_ids)
- block_size = data_args.max_source_length - 1
+ tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
+ concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
+ total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
+ block_size = data_args.max_source_length
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
total_length = (total_length // block_size) * block_size
# split by chunks of max_source_length
- result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
- for i in range(0, total_length, block_size)]
- return {
- "input_ids": result,
- "labels": result.copy()
+ result = {
+ k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
+ for k, t in concatenated_examples.items()
}
+ result["labels"] = result["input_ids"].copy()
+ return result
- def preprocess_supervised_dataset(examples):
+ def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for input with history, we build multiple input-label pairs just like:
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
- model_inputs = {"input_ids": [], "labels": []}
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
max_length = data_args.max_source_length + data_args.max_target_length
- for dialog in get_dialog(examples):
+ for query, response, history, prefix in construct_example(examples):
input_ids, labels = [], []
- for i in range(len(dialog) // 2):
+ for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)):
- source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
+ source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0))
- target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
+ target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
@@ -73,19 +71,20 @@ def preprocess_dataset(
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
+ model_inputs["attention_mask"].append([1] * len(input_ids))
model_inputs["labels"].append(labels)
return model_inputs
- def preprocess_unsupervised_dataset(examples):
+ def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
# build inputs with format `<bos> X` and labels with format `<bos> Y`
- model_inputs = {"input_ids": [], "labels": []}
+ model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
- for dialog in get_dialog(examples):
+ for query, response, history, prefix in construct_example(examples):
- prompt, answer = "".join(dialog[:-1]), dialog[-1]
+ prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
- target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
+ target_ids = tokenizer.encode(text=response, add_special_tokens=True)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
@@ -93,6 +92,7 @@ def preprocess_dataset(
target_ids = target_ids[:data_args.max_target_length]
model_inputs["input_ids"].append(source_ids)
+ model_inputs["attention_mask"].append([1] * len(source_ids))
model_inputs["labels"].append(target_ids)
return model_inputs
@@ -100,12 +100,12 @@ def preprocess_dataset(
def preprocess_pairwise_dataset(examples):
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
model_inputs = {"accept_ids": [], "reject_ids": []}
- for dialog in get_dialog(examples):
+ for query, response, history, prefix in construct_example(examples):
- prompt, answer = "".join(dialog[:-1]), dialog[-1]
+ prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
- accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
+ accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False)
- reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
+ reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False)
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
@@ -141,32 +141,44 @@ def preprocess_dataset(
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
if stage == "pt":
+ dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_pretrain_dataset
- elif stage == "sft":
- preprocess_function = preprocess_unsupervised_dataset \
- if training_args.predict_with_generate else preprocess_supervised_dataset
+ elif stage == "sft" and not training_args.predict_with_generate:
+ dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
+ preprocess_function = preprocess_supervised_dataset
elif stage == "rm":
+ dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
preprocess_function = preprocess_pairwise_dataset
- elif stage == "ppo":
+ else:
+ dataset = dataset.filter(lambda example: example["prompt"])
preprocess_function = preprocess_unsupervised_dataset
with training_args.main_process_first(desc="dataset map pre-processing"):
kwargs = {}
if not data_args.streaming:
kwargs = dict(
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
desc="Running tokenizer on dataset"
)
dataset = dataset.map(
preprocess_function,
batched=True,
- num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
- load_from_cache_file=not data_args.overwrite_cache,
- desc="Running tokenizer on dataset"
+ **kwargs
)
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
if stage == "pt":
- print_unsupervised_dataset_example(dataset[0])
+ print_unsupervised_dataset_example(next(iter(dataset)))
elif stage == "sft":
- print_supervised_dataset_example(dataset[0])
+ print_supervised_dataset_example(next(iter(dataset)))
elif stage == "rm":
- print_pairwise_dataset_example(dataset[0])
+ print_pairwise_dataset_example(next(iter(dataset)))
elif stage == "ppo":
- print_unsupervised_dataset_example(dataset[0])
+ print_unsupervised_dataset_example(next(iter(dataset)))
return dataset
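The grouped-text logic in `preprocess_pretrain_dataset` can be reproduced in isolation roughly as follows; the token ids below are invented:

from itertools import chain

block_size = 4
tokenized_examples = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9, 10]]}
# concatenate every example, drop the remainder, then cut fixed-size blocks
concatenated = {k: list(chain(*v)) for k, v in tokenized_examples.items()}
total_length = (len(concatenated["input_ids"]) // block_size) * block_size
result = {
    k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
    for k, t in concatenated.items()
}
result["labels"] = result["input_ids"].copy()
print(result)  # two blocks of four ids each, labels identical to input_ids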

View File

@@ -0,0 +1,15 @@
from typing import TYPE_CHECKING, Dict
if TYPE_CHECKING:
from datasets import Dataset
def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
if do_train:
if dev_ratio > 1e-6: # Split the dataset
dataset = dataset.train_test_split(test_size=dev_ratio)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}
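A quick usage sketch for the new `split_dataset` helper, assuming the llmtuner package is importable and using a toy dataset:

from datasets import Dataset
from llmtuner.dsets.utils import split_dataset

dataset = Dataset.from_dict({"prompt": ["example {}".format(i) for i in range(10)]})
# during training, dev_ratio carves out an evaluation split
splits = split_dataset(dataset, dev_ratio=0.1, do_train=True)
print(len(splits["train_dataset"]), len(splits["eval_dataset"]))  # 9 1
# during evaluation or prediction the whole dataset is used as-is
print(split_dataset(dataset, dev_ratio=0.1, do_train=False).keys())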

View File

@@ -1,16 +1,13 @@
import os
import json
import time
+ from typing import TYPE_CHECKING
from datetime import timedelta
- from transformers import (
- TrainerCallback,
- TrainerControl,
- TrainerState,
- TrainingArguments
- )
- from transformers.trainer_callback import TrainerControl, TrainerState
- from transformers.training_args import TrainingArguments
+ from transformers import TrainerCallback
+ if TYPE_CHECKING:
+ from transformers import TrainingArguments, TrainerState, TrainerControl
class LogCallback(TrainerCallback):
@@ -20,13 +17,13 @@ class LogCallback(TrainerCallback):
self.start_time = time.time()
self.tracker = {}
- def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
"""
self.start_time = time.time()
- def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of a training step. If using gradient accumulation, one training step
might take several inputs.
@@ -35,7 +32,7 @@ class LogCallback(TrainerCallback):
control.should_epoch_stop = True
control.should_training_stop = True
- def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
"""
@@ -43,7 +40,7 @@ class LogCallback(TrainerCallback):
control.should_epoch_stop = True
control.should_training_stop = True
- def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
+ def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
r"""
Event called after logging the last logs.
"""

View File

@@ -13,6 +13,12 @@ SUPPORTED_MODELS = {
"LLaMA-13B": "huggyllama/llama-13b",
"LLaMA-30B": "huggyllama/llama-30b",
"LLaMA-65B": "huggyllama/llama-65b",
+ "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
+ "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
+ "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
+ "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
+ "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
+ "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
"BLOOM-560M": "bigscience/bloom-560m",
"BLOOM-3B": "bigscience/bloom-3b",
"BLOOM-7B1": "bigscience/bloom-7b1",
@@ -30,8 +36,9 @@ SUPPORTED_MODELS = {
"InternLM-7B-Chat": "internlm/internlm-chat-7b"
}
- DEFAULT_MODULE = { # will be deprecated
+ DEFAULT_MODULE = {
"LLaMA": "q_proj,v_proj",
+ "LLaMA2": "q_proj,v_proj",
"BLOOM": "query_key_value",
"BLOOMZ": "query_key_value",
"Falcon": "query_key_value",

View File

@@ -16,8 +16,16 @@ class LoggerHandler(logging.Handler):
self.log += "\n\n"
- def get_logger(name: str) -> logging.Logger:
+ def reset_logging():
r"""
Removes basic config of root logger
"""
root = logging.getLogger()
list(map(root.removeHandler, root.handlers))
list(map(root.removeFilter, root.filters))
def get_logger(name: str) -> logging.Logger:
formatter = logging.Formatter(
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S"

View File

@@ -1,12 +1,14 @@
import torch
- from typing import List, Optional
+ from typing import TYPE_CHECKING, List, Optional, Tuple
- from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor
from llmtuner.extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
class AverageMeter:
r"""
@@ -28,7 +30,7 @@ class AverageMeter:
self.avg = self.sum / self.count
- # Avoid runtime error in model.generate(do_sample=True).
+ # Avoids runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -44,29 +46,37 @@ def get_logits_processor() -> LogitsProcessorList:
return logits_processor
- def print_trainable_params(model: torch.nn.Module) -> None:
+ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
r"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
if param.__class__.__name__ == "Params4bit":
num_params = num_params * 2
all_param += num_params
if param.requires_grad:
trainable_params += num_params
- print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
- trainable_params, all_param, 100 * trainable_params / all_param))
+ return trainable_params, all_param
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
def prepare_model_for_training(
- model: PreTrainedModel,
+ model: "PreTrainedModel",
finetuning_type: str,
- output_embedding_layer_name: Optional[str] = "lm_head",
+ output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
- ) -> PreTrainedModel:
+ ) -> "PreTrainedModel":
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
@@ -83,19 +93,23 @@ def prepare_model_for_training(
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled
- if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
+ if finetuning_type != "full" and hasattr(model, output_layer_name):
- output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
- input_dtype = output_embedding_layer.weight.dtype
+ if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
+ model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
input_dtype = output_layer.weight.dtype
class CastOutputToFloat(torch.nn.Sequential):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return super().forward(x.to(input_dtype)).to(torch.float32)
- setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
+ setattr(model, output_layer_name, CastOutputToFloat(output_layer))
return model
def torch_gc() -> None:
r"""
Collects GPU memory.
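For reference, a simplified re-implementation of `count_parameters` (without the DeepSpeed and 4-bit special cases) together with the log line that `load_model_and_tokenizer` now emits:

import torch

def count_parameters(model: torch.nn.Module):
    # trainable vs. total parameter counts, mirroring the helper above
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    return trainable_params, all_param

model = torch.nn.Linear(16, 4)
model.bias.requires_grad_(False)
trainable_params, all_param = count_parameters(model)
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
    trainable_params, all_param, 100 * trainable_params / all_param
))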

View File

@@ -12,8 +12,8 @@ from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
- def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
+ def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
- state_dict = model.state_dict()
+ state_dict: Dict[str, torch.Tensor] = model.state_dict()
filtered_state_dict = {}
for k, v in model.named_parameters():

View File

@@ -3,66 +3,61 @@ from dataclasses import dataclass
@dataclass
- class Format:
+ class Template:
prefix: str
prompt: str
sep: str
use_history: bool
templates: Dict[str, Format] = {}
@dataclass
class Template:
name: str
def __post_init__(self):
if self.name in templates:
self.prefix = templates[self.name].prefix
self.prompt = templates[self.name].prompt
self.sep = templates[self.name].sep
self.use_history = templates[self.name].use_history
else:
raise ValueError("Template {} does not exist.".format(self.name))
def get_prompt(
- self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
+ self,
+ query: str,
+ history: Optional[List[Tuple[str, str]]] = None,
+ prefix: Optional[str] = "",
+ eos_token: Optional[str] = "</s>"
) -> str:
r"""
Returns a string containing prompt without response.
"""
- return "".join(self._format_example(query, history, prefix))
+ return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix)))
def get_dialog(
- self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
- ) -> List[str]:
+ self,
+ query: str,
+ resp: str,
+ history: Optional[List[Tuple[str, str]]] = None,
+ prefix: Optional[str] = ""
+ ) -> List[Tuple[str, str]]:
r"""
- Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
+ Returns a list containing prompt-response pairs.
"""
- return self._format_example(query, history, prefix) + [resp]
+ result = self._format_example(query, history, prefix)
+ result[-1][-1] = resp
+ return result
def _format_example(
- self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
- ) -> List[str]:
- prefix = prefix if prefix else self.prefix # use prefix if provided
+ self,
+ query: str,
+ history: Optional[List[Tuple[str, str]]] = None,
+ prefix: Optional[str] = ""
+ ) -> List[Tuple[str, str]]:
+ prefix = prefix or self.prefix # use prefix if provided
prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
history = history if (history and self.use_history) else []
- history = history + [(query, "<dummy>")]
+ history = history + [(query, "")]
- convs = []
- for turn_idx, (user_query, bot_resp) in enumerate(history):
- if turn_idx == 0:
- convs.append(prefix + self.prompt.format(query=user_query))
- convs.append(bot_resp)
- else:
- convs.append(self.sep + self.prompt.format(query=user_query))
- convs.append(bot_resp)
- return convs[:-1] # drop last
+ convs = [
+ [(self.sep if turn_idx else prefix) + self.prompt.format(query=query_i), resp_i]
+ for turn_idx, (query_i, resp_i) in enumerate(history)
+ ]
+ return convs
+ templates: Dict[str, Template] = {}
def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
- templates[name] = Format(
+ templates[name] = Template(
prefix=prefix,
prompt=prompt,
sep=sep,
@@ -70,6 +65,12 @@ def register_template(name: str, prefix: str, prompt: str, sep: str, use_history
)
def get_template(name: str) -> Template:
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
return template
r"""
Supports language model inference without histories.
"""
@@ -95,6 +96,27 @@ register_template(
)
r"""
Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
"""
register_template(
name="llama2",
prefix="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.\n"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
prompt=" [INST] {query} [/INST] ",
sep="",
use_history=True
)
r"""
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
https://github.com/ymcui/Chinese-LLaMA-Alpaca
@@ -118,7 +140,7 @@ register_template(
prefix="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
prompt="USER: {query} ASSISTANT: ",
- sep="</s>",
+ sep="",
use_history=True
)
@@ -202,8 +224,8 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
register_template(
name="baichuan",
prefix="",
- prompt=" <reserved_102> {query} <reserved_103> ",
+ prompt="<reserved_102>{query}<reserved_103>",
- sep="</s>",
+ sep="",
use_history=True
)
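A usage sketch for the reworked Template API, assuming the llmtuner package is importable; the conversation content is invented:

from llmtuner.extras.template import get_template

template = get_template("llama2")
history = [("Hi!", "Hello! How can I help you?")]
# prompt string for generation: turns are joined with the eos token,
# and the response slot of the final turn is left empty
print(template.get_prompt("What is LoRA?", history=history, eos_token="</s>"))
# prompt-response pairs for supervised preprocessing
print(template.get_dialog("What is LoRA?", "A parameter-efficient tuning method.", history=history))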

View File

@@ -1,6 +1,6 @@
import os
import json
- from typing import List, Optional
+ from typing import List, Literal, Optional
from dataclasses import dataclass, field
@@ -16,10 +16,10 @@ class DatasetAttr:
return self.dataset_name
def __post_init__(self):
- self.prompt_column = "instruction"
+ self.prompt = "instruction"
- self.query_column = "input"
+ self.query = "input"
- self.response_column = "output"
+ self.response = "output"
- self.history_column = None
+ self.history = None
@dataclass
@@ -27,8 +27,11 @@ class DataArguments:
"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
template: str = field(
metadata={"help": "Which template to use for constructing prompts in training and inference."}
)
dataset: Optional[str] = field(
- default="alpaca_zh",
+ default="alpaca_en",
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
)
dataset_dir: Optional[str] = field(
@@ -39,6 +42,18 @@ class DataArguments:
default="train",
metadata={"help": "Which dataset split to use for training and evaluation."}
)
streaming: Optional[bool] = field(
default=False,
metadata={"help": "Enable streaming mode."}
)
buffer_size: Optional[int] = field(
default=16384,
metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
)
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
default="concat",
metadata={"help": "Strategy to use in dataset mixing."}
)
overwrite_cache: Optional[bool] = field(
default=False,
metadata={"help": "Overwrite the cached training and evaluation sets."}
@@ -75,10 +90,6 @@ class DataArguments:
default=0,
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
)
prompt_template: Optional[str] = field(
default="default",
metadata={"help": "Which template to use for constructing prompts in training and inference."}
)
def init_for_training(self): # support mixing multiple datasets
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
@@ -111,9 +122,9 @@ class DataArguments:
dataset_attr.source_prefix = prefix_list[i]
if "columns" in dataset_info[name]:
- dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
+ dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
- dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
+ dataset_attr.query = dataset_info[name]["columns"].get("query", None)
- dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
+ dataset_attr.response = dataset_info[name]["columns"].get("response", None)
- dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
+ dataset_attr.history = dataset_info[name]["columns"].get("history", None)
self.dataset_list.append(dataset_attr)
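The fields added above surface as training arguments; a hypothetical configuration dictionary illustrating the new knobs (keys mirror the dataclass fields, values are examples only):

train_args = {
    "stage": "sft",
    "template": "llama2",                 # now required instead of prompt_template
    "dataset": "alpaca_en",
    "streaming": True,                    # stream the dataset instead of caching it
    "buffer_size": 16384,                 # shuffle buffer used in streaming mode
    "mix_strategy": "interleave_under",   # or "concat" / "interleave_over"
    "max_steps": 1000,                    # required in streaming mode (no dataset length)
}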

View File

@@ -16,9 +16,10 @@ class FinetuningArguments:
default=32,
metadata={"help": "Number of decoder blocks in the model. \
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
BLOOM choices: [\"24\", \"30\", \"70\"], \
Falcon choices: [\"32\", \"60\"], \
- Baichuan choices: [\"32\"]"}
+ Baichuan choices: [\"32\", \"40\"]"}
)
num_layer_trainable: Optional[int] = field(
default=3,
@@ -27,7 +28,7 @@ class FinetuningArguments:
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
default="mlp",
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
- LLaMA choices: [\"mlp\", \"self_attn\"], \
+ LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
Baichuan choices: [\"mlp\", \"self_attn\"]"}
)
@@ -46,7 +47,7 @@ class FinetuningArguments:
lora_target: Optional[str] = field(
default="q_proj,v_proj",
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
- LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+ LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
)

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
@dataclass
class GeneralArguments:
"""
- Arguments pertaining to which techniques we are going to fine-tuning with.
+ Arguments pertaining to which stage we are going to perform.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
default="sft",

View File

@@ -1,7 +1,7 @@
import os
import torch
+ from typing import TYPE_CHECKING
- from transformers.modeling_utils import PreTrainedModel
from peft import (
PeftModel,
TaskType,
@@ -12,19 +12,22 @@ from peft.utils import CONFIG_NAME, WEIGHTS_NAME
from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import load_trainable_params
- from llmtuner.hparams import ModelArguments, FinetuningArguments
if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
from llmtuner.hparams import ModelArguments, FinetuningArguments
logger = get_logger(__name__)
def init_adapter(
- model: PreTrainedModel,
+ model: "PreTrainedModel",
- model_args: ModelArguments,
+ model_args: "ModelArguments",
- finetuning_args: FinetuningArguments,
+ finetuning_args: "FinetuningArguments",
is_trainable: bool,
is_mergeable: bool
- ) -> PreTrainedModel:
+ ) -> "PreTrainedModel":
r"""
Initializes the adapters.

View File

@@ -1,6 +1,6 @@
import os
import torch
- from typing import Literal, Optional, Tuple
+ from typing import TYPE_CHECKING, Literal, Optional, Tuple
from transformers import (
AutoConfig,
@@ -10,30 +10,34 @@ from transformers import (
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizerBase
from trl import AutoModelForCausalLMWithValueHead
- from llmtuner.extras.logging import get_logger
+ from llmtuner.extras.logging import reset_logging, get_logger
- from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params
+ from llmtuner.extras.misc import count_parameters, prepare_model_for_training
from llmtuner.extras.save_and_load import load_valuehead_params
- from llmtuner.hparams import ModelArguments, FinetuningArguments
+ from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.adapter import init_adapter
if TYPE_CHECKING:
from llmtuner.hparams import ModelArguments
logger = get_logger(__name__)
check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
- require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
+ require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
- require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
+ require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
def load_model_and_tokenizer(
- model_args: ModelArguments,
+ model_args: "ModelArguments",
- finetuning_args: FinetuningArguments,
+ finetuning_args: "FinetuningArguments",
is_trainable: Optional[bool] = False,
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
@@ -80,9 +84,6 @@ def load_model_and_tokenizer(
elif model_args.quantization_bit == 4:
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
config_kwargs["load_in_4bit"] = True
config_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
@@ -92,11 +93,13 @@ def load_model_and_tokenizer(
)
is_mergeable = False
- config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
- if not is_trainable: # `device_map=auto` should be used for inference only
- config_kwargs["device_map"] = "auto"
+ if (
+ model_args.quantization_bit is not None
or (os.environ.get('LOCAL_RANK') is not None and not is_deepspeed_zero3_enabled())
):
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
model_to_load = model_args.checkpoint_dir[0]
@@ -108,7 +111,7 @@ def load_model_and_tokenizer(
model_to_load,
config=config,
torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
- low_cpu_mem_usage=True,
+ low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
)
@@ -126,6 +129,7 @@ def load_model_and_tokenizer(
if stage == "rm" or stage == "ppo": # add value head
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
@@ -146,6 +150,9 @@ def load_model_and_tokenizer(
model.requires_grad_(False) # fix all model params
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
- print_trainable_params(model)
+ trainable_params, all_param = count_parameters(model)
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
trainable_params, all_param, 100 * trainable_params / all_param
))
return model, tokenizer
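The 4-bit branch above builds a BitsAndBytesConfig; only `load_in_4bit=True` is visible in this hunk, so the remaining fields below are typical values and an assumption, not the repository's exact settings:

import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # assumed; derived from ModelArguments in practice
    bnb_4bit_use_double_quant=True,        # assumed
    bnb_4bit_quant_type="nf4",             # assumed
)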

View File

@@ -19,20 +19,39 @@ from llmtuner.hparams import (
logger = get_logger(__name__)
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
if args is not None:
return parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
return parser.parse_args_into_dataclasses()
def parse_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
parser = HfArgumentParser((
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
))
return _parse_args(parser, args)
def parse_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
parser = HfArgumentParser((
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
))
return _parse_args(parser, args)
def get_train_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))
if args is not None:
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()
# Setup logging
if training_args.should_log:
@@ -54,12 +73,21 @@ def get_train_args(
assert not (training_args.do_train and training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True while training."
- assert (not training_args.do_predict) or training_args.predict_with_generate, \
+ assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
"Please enable `predict_with_generate` to save model predictions."
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
"Quantization is only compatible with the LoRA method."
assert not (training_args.max_steps == -1 and data_args.streaming), \
"Please specify `max_steps` in streaming mode."
assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
"Streaming mode does not support evaluation currently."
assert not (general_args.stage == "ppo" and data_args.streaming), \
"Streaming mode does not suppport PPO training currently."
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
@@ -73,13 +101,22 @@ def get_train_args(
if training_args.do_train and (not training_args.fp16):
logger.warning("We recommend enable fp16 mixed precision training.")
- if data_args.prompt_template == "default":
- logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
- if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
- logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
+ if (
+ training_args.local_rank != -1
+ and training_args.ddp_find_unused_parameters is None
+ and finetuning_args.finetuning_type == "lora"
+ ):
+ logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
training_args.ddp_find_unused_parameters = False
if data_args.max_samples is not None and data_args.streaming:
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
data_args.max_samples = None
if data_args.dev_ratio > 1e-6 and data_args.streaming:
logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
data_args.dev_ratio = 0
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
if model_args.quantization_bit is not None:
@@ -91,10 +128,10 @@ def get_train_args(
model_args.compute_dtype = torch.float32
# Log on each process the small summary:
- logger.info(
- f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
- + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
- )
+ logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format(
+ training_args.local_rank, training_args.device, training_args.n_gpu,
+ bool(training_args.local_rank != -1), training_args.fp16
+ ))
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
@@ -106,17 +143,7 @@ def get_train_args(
def get_infer_args(
args: Optional[Dict[str, Any]] = None
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))
if args is not None:
model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
else:
model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
"Quantization is only compatible with the LoRA method."
@@ -128,7 +155,4 @@ def get_infer_args(
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
"Quantized model only accepts a single checkpoint."
if data_args.prompt_template == "default":
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
return model_args, data_args, finetuning_args, generating_args return model_args, data_args, finetuning_args, generating_args

View File

@@ -1,15 +1,19 @@
import os import os
import torch import torch
from typing import Dict, Optional from typing import TYPE_CHECKING, Dict, Optional
from transformers import Seq2SeqTrainer from transformers import Seq2SeqTrainer
from transformers.trainer import TRAINING_ARGS_NAME from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
from transformers.modeling_utils import unwrap_model from transformers.modeling_utils import PreTrainedModel, unwrap_model
from peft import PeftModel
from trl import PreTrainedModelWrapper
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
from llmtuner.extras.logging import get_logger from llmtuner.extras.logging import get_logger
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
from llmtuner.hparams import FinetuningArguments
if TYPE_CHECKING:
from llmtuner.hparams import FinetuningArguments
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -20,7 +24,7 @@ class PeftTrainer(Seq2SeqTrainer):
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints. Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
""" """
def __init__(self, finetuning_args: FinetuningArguments, **kwargs): def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.finetuning_args = finetuning_args self.finetuning_args = finetuning_args
self._remove_log() self._remove_log()
@@ -41,29 +45,34 @@ class PeftTrainer(Seq2SeqTrainer):
     output_dir = output_dir if output_dir is not None else self.args.output_dir
     os.makedirs(output_dir, exist_ok=True)
     logger.info(f"Saving model checkpoint to {output_dir}")

     model = unwrap_model(self.model)
-    if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
-        backbone_model = getattr(model, "pretrained_model")
-        torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
-    else:
-        backbone_model = model
-
-    if self.finetuning_args.finetuning_type == "lora":
-        backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
-    else: # freeze/full tuning
-        backbone_model.config.use_cache = True
-        backbone_model.save_pretrained(
-            output_dir,
-            state_dict=get_state_dict(backbone_model),
-            safe_serialization=self.args.save_safetensors
-        )
-        backbone_model.config.use_cache = False
-        if self.tokenizer is not None:
-            self.tokenizer.save_pretrained(output_dir)
+    if isinstance(model, PreTrainedModelWrapper):
+        # Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
+        model_state_dict = state_dict or model.state_dict()
+        v_head_state_dict = {
+            name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
+            for name in model_state_dict.keys() if name.startswith("v_head.")
+        }
+        torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
+        model = model.pretrained_model
+
+    state_dict = state_dict or get_state_dict(model)
+    if isinstance(model, (PeftModel, PreTrainedModel)):
+        model.config.use_cache = True
+        model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
+        model.config.use_cache = False
+    else:
+        torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+
+    if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
+        self.tokenizer.save_pretrained(output_dir)

     with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
         f.write(self.args.to_json_string() + "\n")

     self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
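The new save path pulls the value head's tensors out of the wrapper's state dict by key prefix and stores them separately from the backbone. A self-contained toy illustration of that filtering step (a small nn.Module stands in for the real trl wrapper):

```python
# Toy illustration of the key-prefix filtering used above.
import torch
import torch.nn as nn

class ToyValueHeadModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)  # stand-in for the pretrained model
        self.v_head = nn.Linear(4, 1)    # stand-in for the value head

model = ToyValueHeadModel()
state = model.state_dict()
v_head_state = {
    name.replace("v_head.", ""): state[name].cpu().clone().detach()
    for name in state if name.startswith("v_head.")
}
print(sorted(v_head_state))  # ['bias', 'weight'] -- only the value-head tensors remain
```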
def _load_best_model(self): def _load_best_model(self):
@@ -73,16 +82,15 @@ class PeftTrainer(Seq2SeqTrainer):
Subclass and override to inject custom behavior. It should not be directly used by external scripts. Subclass and override to inject custom behavior. It should not be directly used by external scripts.
""" """
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
model = unwrap_model(self.model) model = unwrap_model(self.model)
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
if self.finetuning_args.finetuning_type == "lora": if isinstance(model, PreTrainedModelWrapper):
backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter")) model.v_head.load_state_dict(torch.load(
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint): os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
model.v_head.load_state_dict({ ))
"summary.weight": getattr(model, "reward_head_weight"), model = model.pretrained_model
"summary.bias": getattr(model, "reward_head_bias")
}) if isinstance(model, PeftModel):
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
else: # freeze/full-tuning else: # freeze/full-tuning
load_trainable_params(backbone_model, self.state.best_model_checkpoint) load_trainable_params(model, self.state.best_model_checkpoint)

View File

@@ -2,21 +2,25 @@ import os
import math import math
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from typing import Callable, Dict, List, Optional from typing import TYPE_CHECKING, Callable, Dict, List, Optional
from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl from transformers import TrainerState, TrainerControl
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from trl import PPOTrainer from trl import PPOTrainer
from trl.core import LengthSampler from trl.core import LengthSampler
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.logging import get_logger from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, get_logits_processor from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.hparams import FinetuningArguments
from llmtuner.tuner.core.trainer import PeftTrainer from llmtuner.tuner.core.trainer import PeftTrainer
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from llmtuner.extras.callbacks import LogCallback
from llmtuner.hparams import FinetuningArguments
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -25,11 +29,12 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
r""" r"""
Inherits PPOTrainer. Inherits PPOTrainer.
""" """
def __init__( def __init__(
self, self,
training_args: Seq2SeqTrainingArguments, training_args: "Seq2SeqTrainingArguments",
finetuning_args: FinetuningArguments, finetuning_args: "FinetuningArguments",
callbacks: List[LogCallback], callbacks: List["LogCallback"],
**kwargs **kwargs
): ):
PPOTrainer.__init__(self, **kwargs) PPOTrainer.__init__(self, **kwargs)
@@ -66,7 +71,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}") logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}") logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
# Keyword arguments for `model.generate` # Keyword arguments for `model.generate`
gen_kwargs = { gen_kwargs = {
@@ -107,7 +112,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
# Compute rewards # Compute rewards
replace_model(unwrapped_model, target="reward") replace_model(unwrapped_model, target="reward")
with torch.no_grad(): with torch.no_grad():
-            _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
+            _, _, values = self.model(
+                **self.prepare_model_inputs(queries, responses),
+                output_hidden_states=True,
+                return_dict=True
+            )
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
replace_model(unwrapped_model, target="default") replace_model(unwrapped_model, target="default")

View File

@@ -1,11 +1,13 @@
import torch import torch
from typing import Dict, List, Literal, Optional, Tuple from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.constants import LAYERNORM_NAMES from llmtuner.extras.constants import LAYERNORM_NAMES
if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if target == "reward": # save default head temporarily if target == "reward": # save default head temporarily
valuehead_state_dict = model.v_head.state_dict() valuehead_state_dict = model.v_head.state_dict()
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"]) setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
@@ -19,10 +21,10 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
def cast_layernorm_dtype( def cast_layernorm_dtype(
model: AutoModelForCausalLMWithValueHead, model: "AutoModelForCausalLMWithValueHead",
layer_norm_names: List[str] = LAYERNORM_NAMES, layer_norm_names: List[str] = LAYERNORM_NAMES,
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]: ) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
layer_norm_state_dict = {} layer_norm_state_dict = {}

View File

@@ -2,26 +2,30 @@
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
import math import math
from typing import TYPE_CHECKING
from trl import PPOConfig from trl import PPOConfig
from torch.optim import AdamW from torch.optim import AdamW
from typing import Optional, List from typing import Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback from transformers import DataCollatorForSeq2Seq
from transformers.optimization import get_scheduler from transformers.optimization import get_scheduler
from llmtuner.dsets import get_dataset, preprocess_dataset from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.ploting import plot_loss from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_ppo( def run_ppo(
model_args: ModelArguments, model_args: "ModelArguments",
data_args: DataArguments, data_args: "DataArguments",
training_args: Seq2SeqTrainingArguments, training_args: "Seq2SeqTrainingArguments",
finetuning_args: FinetuningArguments, finetuning_args: "FinetuningArguments",
callbacks: Optional[List[TrainerCallback]] = [LogCallback()] callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
): ):
dataset = get_dataset(model_args, data_args) dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo") model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")

View File

@@ -1,24 +1,27 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
import math import math
from typing import Optional, List from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback from transformers import DataCollatorForSeq2Seq
from llmtuner.dsets import get_dataset, preprocess_dataset from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.core.trainer import PeftTrainer from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_pt( def run_pt(
model_args: ModelArguments, model_args: "ModelArguments",
data_args: DataArguments, data_args: "DataArguments",
training_args: Seq2SeqTrainingArguments, training_args: "Seq2SeqTrainingArguments",
finetuning_args: FinetuningArguments, finetuning_args: "FinetuningArguments",
callbacks: Optional[List[TrainerCallback]] = [LogCallback()] callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
): ):
dataset = get_dataset(model_args, data_args) dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt") model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
@@ -28,16 +31,6 @@ def run_pt(
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
) )
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer # Initialize our Trainer
trainer = PeftTrainer( trainer = PeftTrainer(
finetuning_args=finetuning_args, finetuning_args=finetuning_args,
@@ -46,7 +39,7 @@ def run_pt(
tokenizer=tokenizer, tokenizer=tokenizer,
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks, callbacks=callbacks,
**trainer_kwargs **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
) )
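The inline split logic removed above now lives in a shared `split_dataset` helper from `llmtuner.dsets`, reused by the PT, RM, and SFT workflows. The helper itself is not shown in this diff; the following is a plausible sketch inferred from the removed code, not the repository's actual implementation:

```python
# Plausible reconstruction of the helper, inferred from the inline code it
# replaces; the real implementation in llmtuner.dsets may differ.
from typing import Dict, Union
from datasets import Dataset, IterableDataset

def split_dataset(
    dataset: Union[Dataset, IterableDataset], dev_ratio: float, do_train: bool
) -> Dict[str, Union[Dataset, IterableDataset]]:
    if not do_train:                      # do_eval or do_predict
        return {"eval_dataset": dataset}
    if dev_ratio > 1e-6:                  # carve out a validation split
        split = dataset.train_test_split(test_size=dev_ratio)
        return {"train_dataset": split["train"], "eval_dataset": split["test"]}
    return {"train_dataset": dataset}     # train on everything
```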
# Training # Training

View File

@@ -15,5 +15,8 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
We generate 2 * n examples where the first n examples represent chosen examples and We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples. the last n examples represent rejected examples.
""" """
features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features] features = [
{"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
for key in ("accept_ids", "reject_ids") for feature in features
]
return super().__call__(features) return super().__call__(features)
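The collator turns each pairwise record into two ordinary sequence examples, chosen responses first and rejected ones after, now with explicit attention masks so the padded batch is well formed. A toy run of the same comprehension, with padding omitted:

```python
# Toy illustration of the chosen/rejected flattening above (no padding applied).
features = [
    {"accept_ids": [1, 2, 3], "reject_ids": [4, 5]},
    {"accept_ids": [6], "reject_ids": [7, 8, 9]},
]
flat = [
    {"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
    for key in ("accept_ids", "reject_ids") for feature in features
]
# First n entries are the chosen responses, last n the rejected ones:
print([example["input_ids"] for example in flat])  # [[1, 2, 3], [6], [4, 5], [7, 8, 9]]
```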

View File

@@ -1,9 +1,18 @@
import os
import json
import torch import torch
from typing import Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from transformers.modeling_utils import PreTrainedModel
from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.trainer import PeftTrainer from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
from transformers.modeling_utils import PreTrainedModel
logger = get_logger(__name__)
class PairwisePeftTrainer(PeftTrainer): class PairwisePeftTrainer(PeftTrainer):
r""" r"""
@@ -16,7 +25,7 @@ class PairwisePeftTrainer(PeftTrainer):
def compute_loss( def compute_loss(
self, self,
model: PreTrainedModel, model: "PreTrainedModel",
inputs: Dict[str, torch.Tensor], inputs: Dict[str, torch.Tensor],
return_outputs: Optional[bool] = False return_outputs: Optional[bool] = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -32,7 +41,30 @@ class PairwisePeftTrainer(PeftTrainer):
See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509 See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
""" """
batch_size = inputs["input_ids"].size(0) // 2 batch_size = inputs["input_ids"].size(0) // 2
_, _, values = model(**inputs) _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
r_accept, r_reject = values[:, -1].split(batch_size, dim=0) r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean() loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
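The loss is the standard pairwise ranking objective: one forward pass scores all 2n sequences, the last-token values are split into chosen and rejected halves, and the trainer minimizes -log sigmoid(r_accept - r_reject). A tiny numeric check:

```python
# Tiny numeric check of the pairwise objective used above.
import torch

r_accept = torch.tensor([1.2, 0.3])   # last-token values of the chosen responses
r_reject = torch.tensor([0.1, 0.9])   # last-token values of the rejected responses
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
print(loss)  # ~0.66: shrinks as chosen responses outscore rejected ones
```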
def save_predictions(
self,
predict_results: "PredictionOutput"
) -> None:
r"""
Saves model predictions to `output_dir`.
A custom behavior that is not contained in Seq2SeqTrainer.
"""
if not self.is_world_process_zero():
return
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
logger.info(f"Saving prediction results to {output_prediction_file}")
acc_scores, rej_scores = predict_results.predictions
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for acc_score, rej_score in zip(acc_scores, rej_scores):
res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)}))
writer.write("\n".join(res))

View File

@@ -2,25 +2,27 @@
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
from typing import Optional, List from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.dsets import get_dataset, preprocess_dataset from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.ploting import plot_loss from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.rm.metric import compute_accuracy from llmtuner.tuner.rm.metric import compute_accuracy
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_rm( def run_rm(
model_args: ModelArguments, model_args: "ModelArguments",
data_args: DataArguments, data_args: "DataArguments",
training_args: Seq2SeqTrainingArguments, training_args: "Seq2SeqTrainingArguments",
finetuning_args: FinetuningArguments, finetuning_args: "FinetuningArguments",
callbacks: Optional[List[TrainerCallback]] = [LogCallback()] callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
): ):
dataset = get_dataset(model_args, data_args) dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm") model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
@@ -29,16 +31,6 @@ def run_rm(
training_args.remove_unused_columns = False # important for pairwise dataset training_args.remove_unused_columns = False # important for pairwise dataset
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer # Initialize our Trainer
trainer = PairwisePeftTrainer( trainer = PairwisePeftTrainer(
finetuning_args=finetuning_args, finetuning_args=finetuning_args,
@@ -48,7 +40,7 @@ def run_rm(
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks, callbacks=callbacks,
compute_metrics=compute_accuracy, compute_metrics=compute_accuracy,
**trainer_kwargs **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
) )
# Training # Training
@@ -66,3 +58,10 @@ def run_rm(
metrics = trainer.evaluate(metric_key_prefix="eval") metrics = trainer.evaluate(metric_key_prefix="eval")
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict")
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results)

View File

@@ -1,7 +1,6 @@
import numpy as np import numpy as np
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
from transformers.tokenization_utils import PreTrainedTokenizer
import jieba import jieba
from rouge_chinese import Rouge from rouge_chinese import Rouge
@@ -9,6 +8,9 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.constants import IGNORE_INDEX
if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer
@dataclass @dataclass
class ComputeMetrics: class ComputeMetrics:
@@ -16,7 +18,7 @@ class ComputeMetrics:
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer. Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
""" """
tokenizer: PreTrainedTokenizer tokenizer: "PreTrainedTokenizer"
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
r""" r"""

View File

@@ -3,13 +3,15 @@ import json
import torch import torch
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
from typing import Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from transformers.trainer import PredictionOutput
from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.logging import get_logger from llmtuner.extras.logging import get_logger
from llmtuner.tuner.core.trainer import PeftTrainer from llmtuner.tuner.core.trainer import PeftTrainer
if TYPE_CHECKING:
from transformers.trainer import PredictionOutput
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -36,31 +38,44 @@ class Seq2SeqPeftTrainer(PeftTrainer):
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
if label_len > prompt_len: if label_len > prompt_len:
inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"]) inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
if "attention_mask" in inputs:
inputs["attention_mask"] = self._pad_tensors_to_target_len(
inputs["attention_mask"], inputs["labels"], pad_token_id=0
)
if "position_ids" in inputs:
inputs["position_ids"] = self._pad_tensors_to_target_len(
inputs["position_ids"], inputs["labels"], pad_token_id=0
)
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        generated_tokens = generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
+        generated_tokens = (
+            generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
+        )
         return (loss, generated_tokens, labels)

-    def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
+    def _pad_tensors_to_target_len(
+        self,
+        src_tensor: torch.Tensor,
+        tgt_tensor: torch.Tensor,
+        pad_token_id: Optional[int] = None
+    ) -> torch.Tensor:
         r"""
         Pads the tensor to the same length as the target tensor.

         Should only be called when predict_with_generate=True.
         """
-        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
-            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
-            # If PAD token is not defined at least EOS token has to be defined
-            pad_token_id = (
-                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
-            )
-        else:
-            if self.model.config.pad_token_id is not None:
-                pad_token_id = self.model.config.pad_token_id
-            else:
-                raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
+        if pad_token_id is None:
+            if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
+                assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+                pad_token_id = self.tokenizer.pad_token_id
+            else:
+                if self.model.config.pad_token_id is not None:
+                    pad_token_id = self.model.config.pad_token_id
+                else:
+                    raise ValueError("Pad_token_id must be set in the configuration of the model.")

         padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
         padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
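The extra padding of `attention_mask` and `position_ids` in `prediction_step` follows the same left-padding rule as this helper: the source tensor is kept flush against the right edge and the gap on the left is filled with the pad id. A toy illustration:

```python
# Toy illustration of the left-padding performed by _pad_tensors_to_target_len.
import torch

pad_token_id = 0
src = torch.tensor([[5, 6, 7]])           # e.g. a short prompt
tgt = torch.tensor([[1, 2, 3, 4, 9, 9]])  # the longer tensor to match

padded = pad_token_id * torch.ones_like(tgt)
padded[:, -src.shape[-1]:] = src          # keep the content flush right
print(padded)  # tensor([[0, 0, 0, 5, 6, 7]])
```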
@@ -68,7 +83,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
def save_predictions( def save_predictions(
self, self,
predict_results: PredictionOutput predict_results: "PredictionOutput"
) -> None: ) -> None:
r""" r"""
Saves model predictions to `output_dir`. Saves model predictions to `output_dir`.

View File

@@ -1,25 +1,28 @@
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
from typing import Optional, List from typing import TYPE_CHECKING, Optional, List
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback from transformers import DataCollatorForSeq2Seq
from llmtuner.dsets import get_dataset, preprocess_dataset from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss from llmtuner.extras.ploting import plot_loss
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.core import load_model_and_tokenizer from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_sft( def run_sft(
model_args: ModelArguments, model_args: "ModelArguments",
data_args: DataArguments, data_args: "DataArguments",
training_args: Seq2SeqTrainingArguments, training_args: "Seq2SeqTrainingArguments",
finetuning_args: FinetuningArguments, finetuning_args: "FinetuningArguments",
callbacks: Optional[List[TrainerCallback]] = [LogCallback()] callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
): ):
dataset = get_dataset(model_args, data_args) dataset = get_dataset(model_args, data_args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft") model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
@@ -35,16 +38,6 @@ def run_sft(
training_args.generation_num_beams = data_args.eval_num_beams if \ training_args.generation_num_beams = data_args.eval_num_beams if \
data_args.eval_num_beams is not None else training_args.generation_num_beams data_args.eval_num_beams is not None else training_args.generation_num_beams
# Split the dataset
if training_args.do_train:
if data_args.dev_ratio > 1e-6:
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
trainer_kwargs = {"train_dataset": dataset}
else: # do_eval or do_predict
trainer_kwargs = {"eval_dataset": dataset}
# Initialize our Trainer # Initialize our Trainer
trainer = Seq2SeqPeftTrainer( trainer = Seq2SeqPeftTrainer(
finetuning_args=finetuning_args, finetuning_args=finetuning_args,
@@ -54,7 +47,7 @@ def run_sft(
data_collator=data_collator, data_collator=data_collator,
callbacks=callbacks, callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**trainer_kwargs **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
) )
# Keyword arguments for `model.generate` # Keyword arguments for `model.generate`

View File

@@ -1 +0,0 @@
from llmtuner.webui.interface import create_ui

View File

@@ -54,7 +54,7 @@ class WebChatModel(ChatModel):
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type, finetuning_type=finetuning_type,
quantization_bit=int(quantization_bit) if quantization_bit else None, quantization_bit=int(quantization_bit) if quantization_bit else None,
-            prompt_template=template,
+            template=template,
source_prefix=source_prefix source_prefix=source_prefix
) )
super().__init__(*get_infer_args(args)) super().__init__(*get_infer_args(args))
@@ -73,6 +73,7 @@ class WebChatModel(ChatModel):
chatbot: List[Tuple[str, str]], chatbot: List[Tuple[str, str]],
query: str, query: str,
history: List[Tuple[str, str]], history: List[Tuple[str, str]],
prefix: str,
max_new_tokens: int, max_new_tokens: int,
top_p: float, top_p: float,
temperature: float temperature: float
@@ -80,9 +81,17 @@ class WebChatModel(ChatModel):
chatbot.append([query, ""]) chatbot.append([query, ""])
response = "" response = ""
for new_text in self.stream_chat( for new_text in self.stream_chat(
-            query, history, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
): ):
response += new_text response += new_text
response = self.postprocess(response)
new_history = history + [(query, response)] new_history = history + [(query, response)]
chatbot[-1] = [query, response] chatbot[-1] = [query, response]
yield chatbot, new_history yield chatbot, new_history
def postprocess(self, response: str) -> str:
blocks = response.split("```")
for i, block in enumerate(blocks):
if i % 2 == 0:
blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
return "```".join(blocks)

View File

@@ -1,4 +1,5 @@
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.top import create_top from llmtuner.webui.components.top import create_top
from llmtuner.webui.components.sft import create_sft_tab from llmtuner.webui.components.sft import create_sft_tab
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.export import create_export_tab

View File

@@ -1,42 +1,37 @@
from typing import Dict, Optional, Tuple from typing import TYPE_CHECKING, Dict, Optional, Tuple
import gradio as gr import gradio as gr
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel
def create_chat_box( def create_chat_box(
chat_model: WebChatModel, chat_model: "WebChatModel",
visible: Optional[bool] = False visible: Optional[bool] = False
) -> Tuple[Block, Component, Component, Dict[str, Component]]: ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
     with gr.Box(visible=visible) as chat_box:
         chatbot = gr.Chatbot()

         with gr.Row():
             with gr.Column(scale=4):
-                with gr.Column(scale=12):
-                    query = gr.Textbox(show_label=False, lines=8)
-
-                with gr.Column(min_width=32, scale=1):
-                    submit_btn = gr.Button(variant="primary")
+                prefix = gr.Textbox(show_label=False)
+                query = gr.Textbox(show_label=False, lines=8)
+                submit_btn = gr.Button(variant="primary")

             with gr.Column(scale=1):
                 clear_btn = gr.Button()
-                max_new_tokens = gr.Slider(
-                    10, 2048, value=chat_model.generating_args.max_new_tokens, step=1, interactive=True
-                )
-                top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01, interactive=True)
-                temperature = gr.Slider(
-                    0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01, interactive=True
-                )
+                max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1)
+                top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
+                temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)

     history = gr.State([])

     submit_btn.click(
         chat_model.predict,
-        [chatbot, query, history, max_new_tokens, top_p, temperature],
+        [chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
         [chatbot, history],
         show_progress=True
     ).then(
@@ -46,6 +41,7 @@ def create_chat_box(
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
return chat_box, chatbot, history, dict( return chat_box, chatbot, history, dict(
prefix=prefix,
query=query, query=query,
submit_btn=submit_btn, submit_btn=submit_btn,
clear_btn=clear_btn, clear_btn=clear_btn,

View File

@@ -1,10 +1,12 @@
import gradio as gr import gradio as gr
from gradio.blocks import Block from typing import TYPE_CHECKING, Tuple
from gradio.components import Component
from typing import Tuple if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
def create_preview_box() -> Tuple[Block, Component, Component, Component]: def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
with gr.Box(visible=False, elem_classes="modal-box") as preview_box: with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
with gr.Row(): with gr.Row():
preview_count = gr.Number(interactive=False) preview_count = gr.Number(interactive=False)

View File

@@ -1,14 +1,16 @@
from typing import Dict from typing import TYPE_CHECKING, Dict
import gradio as gr import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview from llmtuner.webui.utils import can_preview, get_preview
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.runner import Runner
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row(): with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4) dataset = gr.Dropdown(multiselect=True, scale=4)
@@ -31,7 +33,8 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
start_btn = gr.Button() start_btn = gr.Button()
stop_btn = gr.Button() stop_btn = gr.Button()
-        output_box = gr.Markdown()
+        with gr.Box():
+            output_box = gr.Markdown()
start_btn.click( start_btn.click(
runner.run_eval, runner.run_eval,

View File

@@ -0,0 +1,36 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr
from llmtuner.webui.utils import export_model
if TYPE_CHECKING:
from gradio.components import Component
def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
with gr.Row():
save_dir = gr.Textbox()
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
export_btn = gr.Button()
info_box = gr.Textbox(show_label=False, interactive=False)
export_btn.click(
export_model,
[
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
top_elems["finetuning_type"],
max_shard_size,
save_dir
],
[info_box]
)
return dict(
save_dir=save_dir,
max_shard_size=max_shard_size,
export_btn=export_btn,
info_box=info_box
)
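The Export tab hands these inputs to an `export_model` utility defined in `llmtuner.webui.utils`, which is not part of this diff. A hedged sketch of what the core of such a routine could look like (names and behaviour are assumptions, not the repository's actual code):

```python
# Hypothetical sketch only: the repository's export_model adds input validation
# and localized status messages, and may merge adapters before saving.
from transformers import AutoModelForCausalLM, AutoTokenizer

def export_checkpoint(model_name_or_path: str, save_dir: str, max_shard_size: int = 10) -> None:
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    # max_shard_size accepts sizes such as "10GB" and controls how weights are split
    model.save_pretrained(save_dir, max_shard_size=f"{max_shard_size}GB")
    tokenizer.save_pretrained(save_dir)
```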

View File

@@ -1,18 +1,20 @@
from typing import Dict from typing import TYPE_CHECKING, Dict
import gradio as gr import gradio as gr
from gradio.components import Component
from llmtuner.webui.chat import WebChatModel from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box from llmtuner.webui.components.chatbot import create_chat_box
if TYPE_CHECKING:
from gradio.components import Component
def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
with gr.Row(): with gr.Row():
load_btn = gr.Button() load_btn = gr.Button()
unload_btn = gr.Button() unload_btn = gr.Button()
-        info_box = gr.Markdown()
+        info_box = gr.Textbox(show_label=False, interactive=False)
chat_model = WebChatModel() chat_model = WebChatModel()
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model) chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)

View File

@@ -1,16 +1,18 @@
from typing import Dict from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType from transformers.trainer_utils import SchedulerType
import gradio as gr import gradio as gr
from gradio.components import Component
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import can_preview, get_preview, gen_plot from llmtuner.webui.utils import can_preview, get_preview, gen_plot
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.runner import Runner
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row(): with gr.Row():
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4) dataset = gr.Dropdown(multiselect=True, scale=4)
@@ -35,21 +37,32 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
             lr_scheduler_type = gr.Dropdown(
                 value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
             )
+            max_grad_norm = gr.Textbox(value="1.0")
             dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
-            fp16 = gr.Checkbox(value=True)

-        with gr.Row():
-            logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
-            save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
+        with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
+            with gr.Row():
+                logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
+                save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
+                warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
+                compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
+
+        with gr.Accordion(label="LoRA config", open=False) as lora_tab:
+            with gr.Row():
+                lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
+                lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
+                lora_target = gr.Textbox(scale=2)

         with gr.Row():
             start_btn = gr.Button()
             stop_btn = gr.Button()

         with gr.Row():
-            with gr.Column(scale=4):
-                output_dir = gr.Textbox(interactive=True)
-                output_box = gr.Markdown()
+            with gr.Column(scale=3):
+                output_dir = gr.Textbox()
+
+                with gr.Box():
+                    output_box = gr.Markdown()

             with gr.Column(scale=1):
                 loss_viewer = gr.Plot()
@@ -74,10 +87,15 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
batch_size, batch_size,
gradient_accumulation_steps, gradient_accumulation_steps,
lr_scheduler_type, lr_scheduler_type,
max_grad_norm,
dev_ratio, dev_ratio,
fp16,
logging_steps, logging_steps,
save_steps, save_steps,
warmup_steps,
compute_type,
lora_rank,
lora_dropout,
lora_target,
output_dir output_dir
], ],
[output_box] [output_box]
@@ -103,10 +121,17 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
batch_size=batch_size, batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps, gradient_accumulation_steps=gradient_accumulation_steps,
lr_scheduler_type=lr_scheduler_type, lr_scheduler_type=lr_scheduler_type,
max_grad_norm=max_grad_norm,
dev_ratio=dev_ratio, dev_ratio=dev_ratio,
fp16=fp16, advanced_tab=advanced_tab,
logging_steps=logging_steps, logging_steps=logging_steps,
save_steps=save_steps, save_steps=save_steps,
warmup_steps=warmup_steps,
compute_type=compute_type,
lora_tab=lora_tab,
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target,
start_btn=start_btn, start_btn=start_btn,
stop_btn=stop_btn, stop_btn=stop_btn,
output_dir=output_dir, output_dir=output_dir,

View File

@@ -1,15 +1,17 @@
from typing import Dict from typing import TYPE_CHECKING, Dict
import gradio as gr import gradio as gr
from gradio.components import Component
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates from llmtuner.extras.template import templates
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
from llmtuner.webui.utils import can_quantize from llmtuner.webui.utils import can_quantize
if TYPE_CHECKING:
from gradio.components import Component
def create_top() -> Dict[str, Component]:
def create_top() -> Dict[str, "Component"]:
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
with gr.Row(): with gr.Row():
@@ -22,10 +24,11 @@ def create_top() -> Dict[str, Component]:
checkpoints = gr.Dropdown(multiselect=True, scale=5) checkpoints = gr.Dropdown(multiselect=True, scale=5)
refresh_btn = gr.Button(scale=1) refresh_btn = gr.Button(scale=1)
with gr.Row(): with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
quantization_bit = gr.Dropdown([8, 4], scale=1) with gr.Row():
template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=2) quantization_bit = gr.Dropdown([8, 4], scale=1)
source_prefix = gr.Textbox(scale=4) template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
source_prefix = gr.Textbox(scale=2)
model_name.change( model_name.change(
list_checkpoint, [model_name, finetuning_type], [checkpoints] list_checkpoint, [model_name, finetuning_type], [checkpoints]
@@ -47,9 +50,10 @@ def create_top() -> Dict[str, Component]:
         model_name=model_name,
         model_path=model_path,
         finetuning_type=finetuning_type,
-        template=template,
         checkpoints=checkpoints,
         refresh_btn=refresh_btn,
+        advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
+        template=template,
         source_prefix=source_prefix
) )

View File

@@ -5,7 +5,8 @@ from llmtuner.webui.components import (
create_top, create_top,
create_sft_tab, create_sft_tab,
create_eval_tab, create_eval_tab,
create_infer_tab create_infer_tab,
create_export_tab
) )
from llmtuner.webui.css import CSS from llmtuner.webui.css import CSS
from llmtuner.webui.manager import Manager from llmtuner.webui.manager import Manager
@@ -27,10 +28,13 @@ def create_ui() -> gr.Blocks:
with gr.Tab("Evaluate"): with gr.Tab("Evaluate"):
eval_elems = create_eval_tab(top_elems, runner) eval_elems = create_eval_tab(top_elems, runner)
with gr.Tab("Inference"): with gr.Tab("Chat"):
infer_elems = create_infer_tab(top_elems) infer_elems = create_infer_tab(top_elems)
elem_list = [top_elems, sft_elems, eval_elems, infer_elems] with gr.Tab("Export"):
export_elems = create_export_tab(top_elems)
elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
manager = Manager(elem_list) manager = Manager(elem_list)
demo.load( demo.load(
@@ -51,4 +55,4 @@ def create_ui() -> gr.Blocks:
if __name__ == "__main__": if __name__ == "__main__":
demo = create_ui() demo = create_ui()
demo.queue() demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True) demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)

View File

@@ -49,6 +49,14 @@ LOCALES = {
"value": "刷新断点" "value": "刷新断点"
} }
}, },
"advanced_tab": {
"en": {
"label": "Advanced configurations"
},
"zh": {
"label": "高级设置"
}
},
"quantization_bit": { "quantization_bit": {
"en": { "en": {
"label": "Quantization bit (optional)", "label": "Quantization bit (optional)",
@@ -71,12 +79,12 @@ LOCALES = {
}, },
"source_prefix": { "source_prefix": {
"en": { "en": {
"label": "Source prefix (optional)", "label": "System prompt (optional)",
"info": "A sequence used as the prefix of each samples." "info": "A sequence used as the default system prompt."
}, },
"zh": { "zh": {
"label": "前缀序列(非必填)", "label": "系统提示词(非必填)",
"info": "作为每个输入样本前缀的序列" "info": "默认使用的系统提示词"
} }
}, },
"dataset_dir": { "dataset_dir": {
@@ -209,6 +217,16 @@ LOCALES = {
"info": "采用的学习率调节器名称。" "info": "采用的学习率调节器名称。"
} }
}, },
"max_grad_norm": {
"en": {
"label": "Maximum gradient norm",
"info": "Norm for gradient clipping.."
},
"zh": {
"label": "最大梯度范数",
"info": "用于梯度裁剪的范数。"
}
},
"dev_ratio": { "dev_ratio": {
"en": { "en": {
"label": "Dev ratio", "label": "Dev ratio",
@@ -219,20 +237,10 @@ LOCALES = {
"info": "验证集占全部样本的百分比。" "info": "验证集占全部样本的百分比。"
} }
}, },
"fp16": {
"en": {
"label": "fp16",
"info": "Whether to use fp16 mixed precision training."
},
"zh": {
"label": "fp16",
"info": "是否启用 FP16 混合精度训练。"
}
},
"logging_steps": { "logging_steps": {
"en": { "en": {
"label": "Logging steps", "label": "Logging steps",
"info": "Number of update steps between two logs." "info": "Number of steps between two logs."
}, },
"zh": { "zh": {
"label": "日志间隔", "label": "日志间隔",
@@ -242,13 +250,71 @@ LOCALES = {
"save_steps": { "save_steps": {
"en": { "en": {
"label": "Save steps", "label": "Save steps",
"info": "Number of updates steps between two checkpoints." "info": "Number of steps between two checkpoints."
}, },
"zh": { "zh": {
"label": "保存间隔", "label": "保存间隔",
"info": "每两次断点保存间的更新步数。" "info": "每两次断点保存间的更新步数。"
} }
}, },
"warmup_steps": {
"en": {
"label": "Warmup steps",
"info": "Number of steps used for warmup."
},
"zh": {
"label": "预热步数",
"info": "学习率预热采用的步数。"
}
},
"compute_type": {
"en": {
"label": "Compute type",
"info": "Whether to use fp16 or bf16 mixed precision training."
},
"zh": {
"label": "计算类型",
"info": "是否启用 FP16 或 BF16 混合精度训练。"
}
},
"lora_tab": {
"en": {
"label": "LoRA configurations"
},
"zh": {
"label": "LoRA 参数设置"
}
},
"lora_rank": {
"en": {
"label": "LoRA rank",
"info": "The rank of LoRA matrices."
},
"zh": {
"label": "LoRA 秩",
"info": "LoRA 矩阵的秩。"
}
},
"lora_dropout": {
"en": {
"label": "LoRA Dropout",
"info": "Dropout ratio of LoRA weights."
},
"zh": {
"label": "LoRA 随机丢弃",
"info": "LoRA 权重随机丢弃的概率。"
}
},
"lora_target": {
"en": {
"label": "LoRA modules (optional)",
"info": "The name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
},
"zh": {
"label": "LoRA 作用层(非必填)",
"info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
}
},
"start_btn": { "start_btn": {
"en": { "en": {
"value": "Start" "value": "Start"
@@ -323,6 +389,14 @@ LOCALES = {
"value": "模型未加载,请先加载模型。" "value": "模型未加载,请先加载模型。"
} }
}, },
"prefix": {
"en": {
"placeholder": "System prompt (optional)"
},
"zh": {
"placeholder": "系统提示词(非必填)"
}
},
"query": { "query": {
"en": { "en": {
"placeholder": "Input..." "placeholder": "Input..."
@@ -378,6 +452,34 @@ LOCALES = {
"zh": { "zh": {
"label": "温度系数" "label": "温度系数"
} }
},
"save_dir": {
"en": {
"label": "Export dir",
"info": "Directory to save exported model."
},
"zh": {
"label": "导出目录",
"info": "保存导出模型的文件夹路径。"
}
},
"max_shard_size": {
"en": {
"label": "Max shard size (GB)",
"info": "The maximum size for a model file."
},
"zh": {
"label": "最大分块大小GB",
"info": "模型文件的最大大小。"
}
},
"export_btn": {
"en": {
"value": "Export"
},
"zh": {
"value": "开始导出"
}
} }
} }
@@ -403,6 +505,14 @@ ALERTS = {
"en": "Please choose a dataset.", "en": "Please choose a dataset.",
"zh": "请选择数据集。" "zh": "请选择数据集。"
}, },
"err_no_checkpoint": {
"en": "Please select a checkpoint.",
"zh": "请选择断点。"
},
"err_no_save_dir": {
"en": "Please provide export dir.",
"zh": "请填写导出目录"
},
"info_aborting": { "info_aborting": {
"en": "Aborted, wait for terminating...", "en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……" "zh": "训练中断,正在等待线程结束……"
@@ -430,5 +540,13 @@ ALERTS = {
"info_unloaded": { "info_unloaded": {
"en": "Model unloaded.", "en": "Model unloaded.",
"zh": "模型已卸载。" "zh": "模型已卸载。"
},
"info_exporting": {
"en": "Exporting model...",
"zh": "正在导出模型……"
},
"info_exported": {
"en": "Model exported.",
"zh": "模型导出完成。"
} }
} }

View File

@@ -1,6 +1,6 @@
import gradio as gr import gradio as gr
from typing import Any, Dict, List
from gradio.components import Component from gradio.components import Component
from typing import Any, Dict, List
from llmtuner.webui.common import get_model_path, list_dataset, load_config from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES from llmtuner.webui.locales import LOCALES
@@ -24,7 +24,7 @@ class Manager:
return refresh_dict return refresh_dict
-    def gen_label(self, lang: str) -> Dict[Component, dict]:
+    def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
update_dict = {} update_dict = {}
refresh_dict = self.gen_refresh() refresh_dict = self.gen_refresh()

View File

@@ -3,10 +3,10 @@ import os
import threading import threading
import time import time
import transformers import transformers
-from typing import List, Optional, Tuple
+from typing import Generator, List, Optional, Tuple
 from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
+from llmtuner.extras.constants import DEFAULT_MODULE
from llmtuner.extras.logging import LoggerHandler from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import get_train_args, run_sft from llmtuner.tuner import get_train_args, run_sft
@@ -25,7 +25,9 @@ class Runner:
self.aborted = True self.aborted = True
self.running = False self.running = False
def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]: def initialize(
self, lang: str, model_name: str, dataset: List[str]
) -> Tuple[str, str, LoggerHandler, LogCallback]:
if self.running: if self.running:
return None, ALERTS["err_conflict"][lang], None, None return None, ALERTS["err_conflict"][lang], None, None
@@ -50,7 +52,9 @@ class Runner:
return model_name_or_path, "", logger_handler, trainer_callback return model_name_or_path, "", logger_handler, trainer_callback
def finalize(self, lang: str, finish_info: Optional[str] = None) -> str: def finalize(
self, lang: str, finish_info: Optional[str] = None
) -> str:
self.running = False self.running = False
torch_gc() torch_gc()
if self.aborted: if self.aborted:
@@ -77,12 +81,17 @@ class Runner:
        batch_size: int,
        gradient_accumulation_steps: int,
        lr_scheduler_type: str,
+        max_grad_norm: str,
        dev_ratio: float,
-        fp16: bool,
        logging_steps: int,
        save_steps: int,
+        warmup_steps: int,
+        compute_type: str,
+        lora_rank: int,
+        lora_dropout: float,
+        lora_target: str,
        output_dir: str
-    ):
+    ) -> Generator[str, None, None]:
        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
        if error:
            yield error
@@ -99,11 +108,10 @@ class Runner:
            model_name_or_path=model_name_or_path,
            do_train=True,
            overwrite_cache=True,
-            lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
            quantization_bit=int(quantization_bit) if quantization_bit else None,
-            prompt_template=template,
+            template=template,
            source_prefix=source_prefix,
            dataset_dir=dataset_dir,
            dataset=",".join(dataset),
@@ -115,9 +123,15 @@ class Runner:
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            lr_scheduler_type=lr_scheduler_type,
-            fp16=fp16,
+            max_grad_norm=float(max_grad_norm),
            logging_steps=logging_steps,
            save_steps=save_steps,
+            warmup_steps=warmup_steps,
+            fp16=(compute_type == "fp16"),
+            bf16=(compute_type == "bf16"),
+            lora_rank=lora_rank,
+            lora_dropout=lora_dropout,
+            lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
        )
@@ -164,7 +178,7 @@ class Runner:
        max_samples: str,
        batch_size: int,
        predict: bool
-    ):
+    ) -> Generator[str, None, None]:
        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
        if error:
            yield error
@@ -187,7 +201,7 @@ class Runner:
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
            quantization_bit=int(quantization_bit) if quantization_bit else None,
-            prompt_template=template,
+            template=template,
            source_prefix=source_prefix,
            dataset_dir=dataset_dir,
            dataset=",".join(dataset),

View File

@@ -3,11 +3,13 @@ import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Generator, List, Tuple
from datetime import datetime

from llmtuner.extras.ploting import smooth
-from llmtuner.webui.common import get_save_dir, DATA_CONFIG
+from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
+from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
+from llmtuner.webui.locales import ALERTS


def format_info(log: str, tracker: dict) -> str:
@@ -83,3 +85,41 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
ax.set_xlabel("step") ax.set_xlabel("step")
ax.set_ylabel("loss") ax.set_ylabel("loss")
return fig return fig
def export_model(
lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
) -> Generator[str, None, None]:
if not model_name:
yield ALERTS["err_no_model"][lang]
return
model_name_or_path = get_model_path(model_name)
if not model_name_or_path:
yield ALERTS["err_no_path"][lang]
return
if not checkpoints:
yield ALERTS["err_no_checkpoint"][lang]
return
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
)
if not save_dir:
yield ALERTS["err_no_save_dir"][lang]
return
args = dict(
model_name_or_path=model_name_or_path,
checkpoint_dir=checkpoint_dir,
finetuning_type=finetuning_type
)
yield ALERTS["info_exporting"][lang]
model_args, _, finetuning_args, _ = get_infer_args(args)
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
tokenizer.save_pretrained(save_dir)
yield ALERTS["info_exported"][lang]

View File

@@ -1,4 +1,4 @@
-from llmtuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
+from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo


def main():

View File

@@ -1,10 +1,10 @@
-from llmtuner import create_ui
+from llmtuner.webui.interface import create_ui


def main():
    demo = create_ui()
    demo.queue()
-    demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)


if __name__ == "__main__":

View File

@@ -5,7 +5,7 @@
import gradio as gr
from transformers.utils.versions import require_version

-from llmtuner import get_infer_args
+from llmtuner.tuner import get_infer_args
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box
from llmtuner.webui.manager import Manager
@@ -24,20 +24,12 @@ def main():
manager = Manager([{"lang": lang}, chat_elems]) manager = Manager([{"lang": lang}, chat_elems])
demo.load( demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
manager.gen_label,
[lang],
[lang] + [elem for elem in chat_elems.values()],
)
lang.change( lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
manager.gen_label,
[lang],
[lang] + [elem for elem in chat_elems.values()],
)
demo.queue() demo.queue()
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True) demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
if __name__ == "__main__": if __name__ == "__main__":