Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0b8e19b6a6 | ||
|
|
8e26eb374e | ||
|
|
9bba01a033 | ||
|
|
661890b8a1 | ||
|
|
772ad4ec6b | ||
|
|
6f65f8cb3b | ||
|
|
43e83548b9 | ||
|
|
dd3f3e9749 | ||
|
|
124f61b404 | ||
|
|
e8748cc6f3 | ||
|
|
fafec8b7a5 | ||
|
|
030daca686 | ||
|
|
ac587438f8 | ||
|
|
c145bbef3c | ||
|
|
745c46ee04 | ||
|
|
a707f5b502 | ||
|
|
dc2e801077 | ||
|
|
b56d5108b2 | ||
|
|
8e6b7034fe | ||
|
|
dad7ca6633 | ||
|
|
a1468139a5 | ||
|
|
49c90044ce | ||
|
|
0f7cdac207 | ||
|
|
c4e9694c6e | ||
|
|
2006a96570 | ||
|
|
5dcd95645f | ||
|
|
9b3304b054 | ||
|
|
e580d4ef41 | ||
|
|
64db4abc68 | ||
|
|
5ba0b80e5c | ||
|
|
7a43ff3d89 | ||
|
|
7e1a1d141a | ||
|
|
6d881f161b | ||
|
|
a02b3e6192 | ||
|
|
bcdee9fc19 | ||
|
|
8b688251be | ||
|
|
718f3382ad | ||
|
|
dc8283d3d7 |
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1,2 +1,3 @@
|
||||
# Auto detect text files and perform LF normalization
|
||||
* text=auto
|
||||
*.json filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
141
README.md
141
README.md
@@ -8,31 +8,40 @@
|
||||
|
||||
👋 Join our [WeChat](assets/wechat.jpg).
|
||||
|
||||
\[ English | [中文](README_zh.md) \]
|
||||
|
||||
## Changelog
|
||||
|
||||
[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
|
||||
|
||||
[23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.
|
||||
|
||||
[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.
|
||||
|
||||
[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
|
||||
|
||||
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Please replace the Baichuan-13B model file with `tests/modeling_baichuan.py` and try `--model_name_or_path path_to_baichuan_model` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--prompt_template baichuan` argument when you are using the Baichuan-13B-Chat model.
|
||||
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model.
|
||||
|
||||
[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
|
||||
|
||||
[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--prompt_template intern` argument when you are using the InternLM-chat model.
|
||||
[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model.
|
||||
|
||||
[23/07/05] Now we support training the **Falcon-7B/40B** models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model.
|
||||
|
||||
[23/06/29] We provide a **reproducible example** of training a chat model using instruction-following datasets, see this [HuggingFace Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.
|
||||
[23/06/29] We provide a **reproducible example** of training a chat model using instruction-following datasets, see this [Hugging Face Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.
|
||||
|
||||
[23/06/22] Now we align the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
|
||||
|
||||
[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model. If you want to train with RTX3090, use `git checkout baichuan-7b-rtx3090` to switch to the `baichuan-7b-rtx3090` branch and try the `--baichuan_rtx_gpu true` argument. (Other RTX series GPUs can also be tried)
|
||||
[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model.
|
||||
|
||||
[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
|
||||
[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
|
||||
|
||||
[23/05/31] Now we support training the **BLOOM & BLOOMZ** models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model.
|
||||
|
||||
## Supported Models
|
||||
|
||||
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
|
||||
- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
|
||||
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
|
||||
- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
|
||||
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
|
||||
@@ -57,36 +66,41 @@
|
||||
## Provided Datasets
|
||||
|
||||
- For pre-training:
|
||||
- [Wiki Demo](data/wiki_demo.txt)
|
||||
- [Wiki Demo (en)](data/wiki_demo.txt)
|
||||
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
||||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- For supervised fine-tuning:
|
||||
- [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
- [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
||||
- [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
||||
- [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [Web QA (Chinese)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [UltraChat](https://github.com/thunlp/UltraChat)
|
||||
- [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Open Assistant (Chinese)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [WebNovel (Chinese)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- For reward model training:
|
||||
- [HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Open Assistant (Chinese)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [GPT-4 Generated Data (Chinese)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Self-cognition (zh)](data/self_cognition.json)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
||||
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- For reward modelling:
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
|
||||
Please refer to [data/README.md](data/README.md) for details.
|
||||
|
||||
Some datasets require confirmation before using them, so we recommend logging in with your HuggingFace account using these commands.
|
||||
Some datasets require confirmation before using them, so we recommend logging in with your Hugging Face account using these commands.
|
||||
|
||||
```bash
|
||||
pip install --upgrade huggingface_hub
|
||||
@@ -103,12 +117,6 @@ huggingface-cli login
|
||||
|
||||
And **powerful GPUs**!
|
||||
|
||||
If you want to enable quantized LoRA (QLoRA) on the Windows platform, you should install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
|
||||
|
||||
```bash
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Data Preparation (optional)
|
||||
@@ -120,6 +128,7 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
|
||||
### Dependence Installation (optional)
|
||||
|
||||
```bash
|
||||
git lfs install
|
||||
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
|
||||
conda create -n llama_etuning python=3.10
|
||||
conda activate llama_etuning
|
||||
@@ -127,12 +136,20 @@ cd LLaMA-Efficient-Tuning
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
|
||||
|
||||
```bash
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
### All-in-one Web UI
|
||||
|
||||
```bash
|
||||
python src/train_web.py
|
||||
```
|
||||
|
||||
Currently the web UI only supports training on a single GPU.
|
||||
|
||||
### (Continually) Pre-Training
|
||||
|
||||
```bash
|
||||
@@ -141,6 +158,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset wiki_demo \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--output_dir path_to_pt_checkpoint \
|
||||
--overwrite_cache \
|
||||
@@ -163,6 +181,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--output_dir path_to_sft_checkpoint \
|
||||
--overwrite_cache \
|
||||
@@ -185,7 +204,10 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset comparison_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_rm_checkpoint \
|
||||
--per_device_train_batch_size 4 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
@@ -206,7 +228,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--reward_model path_to_rm_checkpoint \
|
||||
--output_dir path_to_ppo_checkpoint \
|
||||
@@ -217,7 +241,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 1e-5 \
|
||||
--num_train_epochs 1.0 \
|
||||
--resume_lora_training False \
|
||||
--plot_loss
|
||||
```
|
||||
|
||||
@@ -260,34 +283,57 @@ use_cpu: false
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage pt \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_eval \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_eval_result \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--max_samples 50 \
|
||||
--max_samples 100 \
|
||||
--predict_with_generate
|
||||
```
|
||||
|
||||
We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
|
||||
|
||||
### Predict
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_predict \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_predict_result \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--max_samples 100 \
|
||||
--predict_with_generate
|
||||
```
|
||||
|
||||
### API Demo
|
||||
|
||||
```bash
|
||||
python src/api_demo.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
See `http://localhost:8000/docs` for API documentation.
|
||||
Visit `http://localhost:8000/docs` for API documentation.
|
||||
|
||||
### CLI Demo
|
||||
|
||||
```bash
|
||||
python src/cli_demo.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
@@ -296,6 +342,8 @@ python src/cli_demo.py \
|
||||
```bash
|
||||
python src/web_demo.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
@@ -304,10 +352,18 @@ python src/web_demo.py \
|
||||
```bash
|
||||
python src/export_model.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_export
|
||||
```
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
|
||||
- [ ] Implementing multi-query attention for faster inference.
|
||||
- [ ] Supporting full-parameter RLHF training.
|
||||
|
||||
## License
|
||||
|
||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||
@@ -315,9 +371,10 @@ This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||
Please follow the model licenses to use the corresponding model weights:
|
||||
|
||||
- [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
|
||||
- [LLaMA-2](https://ai.meta.com/llama/license/)
|
||||
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
|
||||
- [Falcon](LICENSE)
|
||||
- [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
|
||||
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
|
||||
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
|
||||
|
||||
## Citation
|
||||
|
||||
399
README_zh.md
Normal file
399
README_zh.md
Normal file
@@ -0,0 +1,399 @@
|
||||
# LLaMA Efficient Tuning
|
||||
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/stargazers)
|
||||
[](LICENSE)
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/commits/main)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/pulls)
|
||||
|
||||
👋 加入我们的[微信群](assets/wechat.jpg)。
|
||||
|
||||
\[ [English](README.md) | 中文 \]
|
||||
|
||||
## 更新日志
|
||||
|
||||
[23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming` 和 `--max_steps 100` 参数来流式加载数据集。
|
||||
|
||||
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft))。
|
||||
|
||||
[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--template llama2` 参数。
|
||||
|
||||
[23/07/18] 我们开发了支持训练和测试的浏览器一键微调界面。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
|
||||
|
||||
[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path path_to_baichuan_model` 和 `--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--template baichuan` 参数。
|
||||
|
||||
[23/07/09] 我们开源了 [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
|
||||
|
||||
[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。请注意使用 InternLM-chat 模型需要添加 `--template intern` 参数。
|
||||
|
||||
[23/07/05] 现在我们支持了 **Falcon-7B/40B** 模型的训练。请尝试使用 `--model_name_or_path tiiuae/falcon-7b` 和 `--lora_target query_key_value` 参数。
|
||||
|
||||
[23/06/29] 我们提供了一个**可复现的**指令模型微调示例,详细内容请查阅 [Hugging Face 项目](https://huggingface.co/hiyouga/baichuan-7b-sft)。
|
||||
|
||||
[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入任意基于 ChatGPT 的应用中。
|
||||
|
||||
[23/06/15] 现在我们支持了 **Baichuan-7B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-7B` 和 `--lora_target W_pack` 参数。
|
||||
|
||||
[23/06/03] 现在我们实现了 4 比特的 LoRA 训练(也称 [QLoRA](https://github.com/artidoro/qlora))。请尝试使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
|
||||
|
||||
[23/05/31] 现在我们支持了 **BLOOM & BLOOMZ** 模型的训练。请尝试使用 `--model_name_or_path bigscience/bloomz-7b1-mt` 和 `--lora_target query_key_value` 参数。
|
||||
|
||||
## 模型
|
||||
|
||||
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
|
||||
- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
|
||||
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
|
||||
- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
|
||||
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
|
||||
- [InternLM](https://github.com/InternLM/InternLM) (7B)
|
||||
|
||||
## 微调方法
|
||||
|
||||
- [二次预训练](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
|
||||
- 全参数微调
|
||||
- 部分参数微调
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
- [指令监督微调](https://arxiv.org/abs/2109.01652)
|
||||
- 全参数微调
|
||||
- 部分参数微调
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
- [人类反馈的强化学习(RLHF)](https://arxiv.org/abs/2203.02155)
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
|
||||
## 数据集
|
||||
|
||||
- 用于二次预训练:
|
||||
- [Wiki Demo (en)](data/wiki_demo.txt)
|
||||
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
||||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- 用于指令监督微调:
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Self-cognition (zh)](data/self_cognition.json)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
||||
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- 用于奖励模型训练:
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
|
||||
使用方法请参考 [data/README.md](data/README_zh.md) 文件。
|
||||
|
||||
部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。
|
||||
|
||||
```bash
|
||||
pip install --upgrade huggingface_hub
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
## 软件依赖
|
||||
|
||||
- Python 3.8+ 和 PyTorch 1.13.1+
|
||||
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
|
||||
- jieba, rouge-chinese 和 nltk (用于评估)
|
||||
- gradio 和 matplotlib (用于网页端交互)
|
||||
- uvicorn, fastapi 和 sse-starlette (用于 API)
|
||||
|
||||
以及 **强而有力的 GPU**!
|
||||
|
||||
## 如何使用
|
||||
|
||||
### 数据准备(可跳过)
|
||||
|
||||
关于数据集文件的格式,请参考 `data/example_dataset` 文件夹的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
|
||||
|
||||
注意:使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README.md`。
|
||||
|
||||
### 环境搭建(可跳过)
|
||||
|
||||
```bash
|
||||
git lfs install
|
||||
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
|
||||
conda create -n llama_etuning python=3.10
|
||||
conda activate llama_etuning
|
||||
cd LLaMA-Efficient-Tuning
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.1.
|
||||
|
||||
```bash
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
### 浏览器一键微调/测试
|
||||
|
||||
```bash
|
||||
python src/train_web.py
|
||||
```
|
||||
|
||||
目前网页 UI 仅支持单卡训练。
|
||||
|
||||
### 二次预训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage pt \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset wiki_demo \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--output_dir path_to_pt_checkpoint \
|
||||
--overwrite_cache \
|
||||
--per_device_train_batch_size 4 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--plot_loss \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### 指令监督微调
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--output_dir path_to_sft_checkpoint \
|
||||
--overwrite_cache \
|
||||
--per_device_train_batch_size 4 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 5e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--plot_loss \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### 奖励模型训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage rm \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset comparison_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_rm_checkpoint \
|
||||
--per_device_train_batch_size 4 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 1e-5 \
|
||||
--num_train_epochs 1.0 \
|
||||
--plot_loss \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### RLHF 训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage ppo \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_train \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--reward_model path_to_rm_checkpoint \
|
||||
--output_dir path_to_ppo_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 1e-5 \
|
||||
--num_train_epochs 1.0 \
|
||||
--plot_loss
|
||||
```
|
||||
|
||||
### 多 GPU 分布式训练
|
||||
|
||||
```bash
|
||||
accelerate config # 首先配置分布式环境
|
||||
accelerate launch src/train_bash.py # 参数同上
|
||||
```
|
||||
|
||||
<details><summary>使用 DeepSpeed ZeRO-2 进行全参数微调的 Accelerate 配置示例</summary>
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_clipping: 0.5
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
num_machines: 1
|
||||
num_processes: 4
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 指标评估(BLEU分数和汉语ROUGE分数)
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_eval \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_eval_result \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--max_samples 100 \
|
||||
--predict_with_generate
|
||||
```
|
||||
|
||||
我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128` 参数。
|
||||
|
||||
### 模型预测
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_predict \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_predict_result \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--max_samples 100 \
|
||||
--predict_with_generate
|
||||
```
|
||||
|
||||
### API 服务
|
||||
|
||||
```bash
|
||||
python src/api_demo.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
关于 API 文档请见 `http://localhost:8000/docs`。
|
||||
|
||||
### 命令行测试
|
||||
|
||||
```bash
|
||||
python src/cli_demo.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### 浏览器测试
|
||||
|
||||
```bash
|
||||
python src/web_demo.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### 导出微调模型
|
||||
|
||||
```bash
|
||||
python src/export_model.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_export
|
||||
```
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] 实现 flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention))。
|
||||
- [ ] 在推理阶段使用 Multi-query attention 进行加速。
|
||||
- [ ] 支持 RLHF 的全参数微调。
|
||||
|
||||
## 协议
|
||||
|
||||
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
||||
|
||||
使用模型权重时,请遵循对应的模型协议:
|
||||
|
||||
- [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
|
||||
- [LLaMA-2](https://ai.meta.com/llama/license/)
|
||||
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
|
||||
- [Falcon](LICENSE)
|
||||
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
|
||||
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
|
||||
|
||||
## 引用
|
||||
|
||||
如果您觉得此项目有帮助,请考虑以下列格式引用
|
||||
|
||||
```bibtex
|
||||
@Misc{llama-efficient-tuning,
|
||||
title = {LLaMA Efficient Tuning},
|
||||
author = {hiyouga},
|
||||
howpublished = {\url{https://github.com/hiyouga/LLaMA-Efficient-Tuning}},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
|
||||
## 致谢
|
||||
|
||||
本项目是 [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning) 的同类项目。采用了类似的代码结构和训练方法。
|
||||
|
||||
## Star History
|
||||
|
||||

|
||||
@@ -1,4 +1,5 @@
|
||||
Data format in `dataset_info.json`:
|
||||
If you are using a custom dataset, please provide your dataset definition in the following format in `dataset_info.json`.
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
|
||||
@@ -14,40 +15,4 @@ Data format in `dataset_info.json`:
|
||||
}
|
||||
```
|
||||
|
||||
`dataset_info.json` 中的数据集定义格式:
|
||||
```json
|
||||
"数据集名称": {
|
||||
"hf_hub_url": "HuggingFace上的项目地址(若指定,则忽略下列三个参数)",
|
||||
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
|
||||
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
|
||||
"file_sha1": "数据集文件的SHA-1哈希值(可选)",
|
||||
"columns": {
|
||||
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
||||
"query": "数据集代表请求的表头名称(默认:input)",
|
||||
"response": "数据集代表回答的表头名称(默认:output)",
|
||||
"history": "数据集代表历史对话的表头名称(默认:None)"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
部分预置数据集简介:
|
||||
|
||||
| 数据集名称 | 规模 | 描述 |
|
||||
| --- | --- | --- |
|
||||
| [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) | 52k | 斯坦福大学开源的 Alpaca 数据集,训练了 Alpaca 这类早期基于 LLaMA 的模型 |
|
||||
| [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) | 51k | 使用 ChatGPT 翻译的 Alpaca 数据集 |
|
||||
| [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) | 100k+ | 基于 GPT-4 的 self-instruction 数据集 |
|
||||
| [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN) | 2m | 包含约 200 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文指令数据 |
|
||||
| [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN) | 1m | 包含约 100 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文指令数据 |
|
||||
| [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) | 500k | 包含约 50 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文指令数据 |
|
||||
| [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) | 400k | 包含约 40 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的个性化角色对话数据,包含角色介绍 |
|
||||
| [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) | 250k | 包含约 25 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文数学题数据,包含解题过程 |
|
||||
| [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) | 800k | 包含约 80 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的用户与助手的多轮对话 |
|
||||
| [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) | 100k+ | 包含日文、简繁体中文、英文等多类数据,数据集原用于 Guanaco 模型训练 |
|
||||
| [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) | 1.1M | 中文对话大模型 firefly(流萤)的中文数据集,包含多个 NLP 任务 |
|
||||
| [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) | 20k | 英文代码生成任务数据集 |
|
||||
| [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) | 6M | 用于微调的指令数据集集合 |
|
||||
| [Web QA](https://huggingface.co/datasets/suolyer/webqa) | 36k | 百度知道汇集的中文问答数据集 |
|
||||
| [UltraChat](https://github.com/thunlp/UltraChat) | 1.57M | 清华 NLP 发布的大规模多轮对话数据集 |
|
||||
|
||||
注:BELLE 数据集是由 ChatGPT 产生的数据集,不保证数据准确性,所有类 GPT 模型产生的 self-instruction 数据集均不能保证其准确性。
|
||||
where the `prompt` and `response` columns should contain non-empty values. The `query` column will be concatenated with the `prompt` column and used as input for the model. The `history` column should contain a list where each element is a string tuple representing a query-response pair.
|
||||
|
||||
18
data/README_zh.md
Normal file
18
data/README_zh.md
Normal file
@@ -0,0 +1,18 @@
|
||||
如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中以如下格式提供您的数据集定义。
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"hf_hub_url": "HuggingFace上的项目地址(若指定,则忽略下列三个参数)",
|
||||
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
|
||||
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
|
||||
"file_sha1": "数据集文件的SHA-1哈希值(可选)",
|
||||
"columns": {
|
||||
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
||||
"query": "数据集代表请求的表头名称(默认:input)",
|
||||
"response": "数据集代表回答的表头名称(默认:output)",
|
||||
"history": "数据集代表历史对话的表头名称(默认:None)"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
其中 `prompt` 和 `response` 列应当是非空的字符串。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。`history` 列应当是一个列表,其中每个元素是一个字符串二元组,分别代表用户请求和模型答复。
|
||||
@@ -1 +0,0 @@
|
||||
3779ddbc040543ab1834ef216c983d6fcc06cc9a
|
||||
@@ -1 +0,0 @@
|
||||
fc9a6a3458caca2af8dafc6181773fe10c6d8657
|
||||
@@ -1 +0,0 @@
|
||||
25508714b7879a1e5a6764ba7f979a980f549f1a
|
||||
@@ -1 +0,0 @@
|
||||
7cb6a7d11455bddc3d495750a2392683d775b184
|
||||
@@ -1 +0,0 @@
|
||||
f5cb08305ff5dc9c17a09809c54c8c8834aadc70
|
||||
@@ -1 +0,0 @@
|
||||
aee47b7b443496e37808d7f34ef10403ff99bcc3
|
||||
@@ -1 +0,0 @@
|
||||
274079ea921762be356de85b18f13fa60b7ba8cb
|
||||
@@ -1 +0,0 @@
|
||||
0a57fbc1d8cb08a8cd71c5eb8425cf59206ffed6
|
||||
@@ -1,2 +0,0 @@
|
||||
{"id": 0,"title": "大卫·亨利","content": "大卫·亨利\n\n大卫·克莱顿·亨利(David Clayton Henrie,),美国演员。近来在迪士尼频道原创电视影集《少年魔法师》(Wizards of Waverly Place)当中演出贾斯汀·鲁索(Justin Russo)一角。\n\n大卫·亨利出生在加州Mission Viejo,在凤凰城长大。他的胞弟劳伦斯·亨利(Lorenzo Henrie)也是演员。大卫·亨利就读夏安传统学校。家中是信奉罗马天主教。 \n\n大卫在2007年拍摄少年魔法师期间认识女演员露西·海尔(Lucy Hale),之后与其交往,于2009年分手。\n\n10岁时,大卫·亨利和SAG在凤凰城签订了合约,并开始走出去试镜。 9岁的时候,在沙加缅度进行商业拍摄,SAG董事建议大卫·亨利搬到洛杉矶。在10岁那年夏天,他和他的家人搬到了好莱坞。他预定他的前2支商业试镜,扮演主要角色为汉堡王和桂格燕麦。他初演电视节目为Providence。 \n\n到了13岁,大卫有了他的第一次重大突破,在福克斯公司的喜剧The Pitts饰演 Petey Pitt一角。大卫下出作品为的Hallmark movie为Monster Maker,和琳达布莱儿、乔治甘迺迪共同演出,并要求回来Hallmark movie公司。 \n\n在18岁时,大卫得到了迪士尼频道原创系列演出机会,该节目2007年10月12日首播。大卫2008年参加了迪士尼频道的游戏节目。他是绿色团队的队长,隔年,为旋风队队长。他在迪士尼原创电影《少年魔法师》之后在《酷爸的疯狂假期》中有饰演一角。\n"}
|
||||
{"id": 1,"title": "大卫·亨利","content": "大卫·亨利\n\n大卫·克莱顿·亨利(David Clayton Henrie,),美国演员。近来在迪士尼频道原创电视影集《少年魔法师》(Wizards of Waverly Place)当中演出贾斯汀·鲁索(Justin Russo)一角。\n\n大卫·亨利出生在加州Mission Viejo,在凤凰城长大。他的胞弟劳伦斯·亨利(Lorenzo Henrie)也是演员。大卫·亨利就读夏安传统学校。家中是信奉罗马天主教。 \n\n大卫在2007年拍摄少年魔法师期间认识女演员露西·海尔(Lucy Hale),之后与其交往,于2009年分手。\n\n10岁时,大卫·亨利和SAG在凤凰城签订了合约,并开始走出去试镜。 9岁的时候,在沙加缅度进行商业拍摄,SAG董事建议大卫·亨利搬到洛杉矶。在10岁那年夏天,他和他的家人搬到了好莱坞。他预定他的前2支商业试镜,扮演主要角色为汉堡王和桂格燕麦。他初演电视节目为Providence。 \n\n到了13岁,大卫有了他的第一次重大突破,在福克斯公司的喜剧The Pitts饰演 Petey Pitt一角。大卫下出作品为的Hallmark movie为Monster Maker,和琳达布莱儿、乔治甘迺迪共同演出,并要求回来Hallmark movie公司。 \n\n在18岁时,大卫得到了迪士尼频道原创系列演出机会,该节目2007年10月12日首播。大卫2008年参加了迪士尼频道的游戏节目。他是绿色团队的队长,隔年,为旋风队队长。他在迪士尼原创电影《少年魔法师》之后在《酷爸的疯狂假期》中有饰演一角。\n"}
|
||||
@@ -1,8 +1,8 @@
|
||||
torch>=1.13.1
|
||||
transformers>=4.29.1
|
||||
datasets>=2.12.0
|
||||
accelerate>=0.19.0
|
||||
peft>=0.3.0
|
||||
accelerate>=0.21.0
|
||||
peft>=0.4.0
|
||||
trl>=0.4.7
|
||||
sentencepiece
|
||||
jieba
|
||||
@@ -10,7 +10,7 @@ rouge-chinese
|
||||
nltk
|
||||
gradio>=3.36.0
|
||||
uvicorn
|
||||
pydantic
|
||||
fastapi
|
||||
pydantic==1.10.11
|
||||
fastapi==0.95.1
|
||||
sse-starlette
|
||||
matplotlib
|
||||
|
||||
@@ -5,9 +5,16 @@
|
||||
|
||||
import uvicorn
|
||||
|
||||
from llmtuner import create_app
|
||||
from llmtuner import ChatModel
|
||||
from llmtuner.api.app import create_app
|
||||
from llmtuner.tuner import get_infer_args
|
||||
|
||||
|
||||
def main():
|
||||
chat_model = ChatModel(*get_infer_args())
|
||||
app = create_app(chat_model)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = create_app()
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
main()
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
# Implements stream chat in command line for fine-tuned models.
|
||||
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
|
||||
|
||||
from llmtuner import ChatModel, get_infer_args
|
||||
from llmtuner import ChatModel
|
||||
from llmtuner.tuner import get_infer_args
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Exports the fine-tuned model.
|
||||
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
|
||||
|
||||
from llmtuner import get_train_args, load_model_and_tokenizer
|
||||
from llmtuner.tuner import get_train_args, load_model_and_tokenizer
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
from llmtuner.api import create_app
|
||||
from llmtuner.chat import ChatModel
|
||||
from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo
|
||||
from llmtuner.webui import create_ui
|
||||
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__version__ = "0.1.4"
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from llmtuner.api.app import create_app
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -31,9 +30,7 @@ async def lifespan(app: FastAPI): # collects GPU memory
|
||||
torch_gc()
|
||||
|
||||
|
||||
def create_app():
|
||||
chat_model = ChatModel(*get_infer_args())
|
||||
|
||||
def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
@@ -96,7 +93,7 @@ def create_app():
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield json.dumps(chunk, ensure_ascii=False)
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
for new_text in chat_model.stream_chat(
|
||||
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
|
||||
@@ -110,7 +107,7 @@ def create_app():
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield json.dumps(chunk, ensure_ascii=False)
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
@@ -118,12 +115,13 @@ def create_app():
|
||||
finish_reason=Finish.STOP
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield json.dumps(chunk, ensure_ascii=False)
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield "[DONE]"
|
||||
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = create_app()
|
||||
chat_model = ChatModel(*get_infer_args())
|
||||
app = create_app(chat_model)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
|
||||
@@ -1,37 +1,54 @@
|
||||
import torch
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer
|
||||
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.template import Template
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
from llmtuner.extras.template import get_template
|
||||
from llmtuner.tuner import load_model_and_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
class ChatModel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
generating_args: GeneratingArguments
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments"
|
||||
) -> None:
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
self.template = Template(data_args.prompt_template)
|
||||
self.source_prefix = data_args.source_prefix if data_args.source_prefix else ""
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
||||
device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model))
|
||||
self.model = dispatch_model(self.model, device_map)
|
||||
else:
|
||||
self.model = self.model.cuda()
|
||||
|
||||
self.template = get_template(data_args.template)
|
||||
self.source_prefix = data_args.source_prefix
|
||||
self.generating_args = generating_args
|
||||
|
||||
def process_args(
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
prefix = prefix if prefix else self.source_prefix
|
||||
prefix = prefix or self.source_prefix
|
||||
|
||||
inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt")
|
||||
prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token)
|
||||
inputs = self.tokenizer([prompt], return_tensors="pt")
|
||||
inputs = inputs.to(self.model.device)
|
||||
prompt_length = len(inputs["input_ids"][0])
|
||||
|
||||
do_sample = input_kwargs.pop("do_sample", None)
|
||||
temperature = input_kwargs.pop("temperature", None)
|
||||
top_p = input_kwargs.pop("top_p", None)
|
||||
top_k = input_kwargs.pop("top_k", None)
|
||||
@@ -42,6 +59,7 @@ class ChatModel:
|
||||
gen_kwargs = self.generating_args.to_dict()
|
||||
gen_kwargs.update(dict(
|
||||
input_ids=inputs["input_ids"],
|
||||
do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
|
||||
temperature=temperature or gen_kwargs["temperature"],
|
||||
top_p=top_p or gen_kwargs["top_p"],
|
||||
top_k=top_k or gen_kwargs["top_k"],
|
||||
@@ -61,7 +79,11 @@ class ChatModel:
|
||||
|
||||
@torch.inference_mode()
|
||||
def chat(
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[str, Tuple[int, int]]:
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
|
||||
generation_output = self.model.generate(**gen_kwargs)
|
||||
@@ -72,7 +94,11 @@ class ChatModel:
|
||||
|
||||
@torch.inference_mode()
|
||||
def stream_chat(
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
@@ -81,5 +107,4 @@ class ChatModel:
|
||||
thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
|
||||
thread.start()
|
||||
|
||||
for new_text in streamer:
|
||||
yield new_text
|
||||
yield from streamer
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from llmtuner.dsets.loader import get_dataset
|
||||
from llmtuner.dsets.preprocess import preprocess_dataset
|
||||
from llmtuner.dsets.utils import split_dataset
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from datetime import timedelta
|
||||
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments
|
||||
)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
|
||||
def __init__(self, runner=None):
|
||||
self.runner = runner
|
||||
self.start_time = time.time()
|
||||
self.tracker = {}
|
||||
|
||||
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of a training step. If using gradient accumulation, one training step
|
||||
might take several inputs.
|
||||
"""
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
"""
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
if "loss" not in state.log_history[-1]:
|
||||
return
|
||||
cur_time = time.time()
|
||||
cur_steps = state.log_history[-1].get("step")
|
||||
elapsed_time = cur_time - self.start_time
|
||||
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
||||
remaining_steps = state.max_steps - cur_steps
|
||||
remaining_time = remaining_steps * avg_time_per_step
|
||||
self.tracker = {
|
||||
"current_steps": cur_steps,
|
||||
"total_steps": state.max_steps,
|
||||
"loss": state.log_history[-1].get("loss", None),
|
||||
"reward": state.log_history[-1].get("reward", None),
|
||||
"learning_rate": state.log_history[-1].get("learning_rate", None),
|
||||
"epoch": state.log_history[-1].get("epoch", None),
|
||||
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
|
||||
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
|
||||
"remaining_time": str(timedelta(seconds=int(remaining_time)))
|
||||
}
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self.tracker) + "\n")
|
||||
@@ -1,40 +1,50 @@
|
||||
import os
|
||||
import hashlib
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from datasets import Dataset, concatenate_datasets, load_dataset
|
||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.hparams import ModelArguments, DataArguments
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
from llmtuner.hparams import ModelArguments, DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_dataset(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments
|
||||
) -> Dataset:
|
||||
|
||||
def checksum(file_path, hash):
|
||||
with open(file_path, "rb") as datafile:
|
||||
binary_data = datafile.read()
|
||||
sha1 = hashlib.sha1(binary_data).hexdigest()
|
||||
if sha1 != hash:
|
||||
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
|
||||
|
||||
ext2type = {
|
||||
EXT2TYPE = {
|
||||
"csv": "csv",
|
||||
"json": "json",
|
||||
"jsonl": "json",
|
||||
"txt": "text"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
||||
if file_sha1 is None:
|
||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
||||
return
|
||||
|
||||
if len(data_files) != 1:
|
||||
logger.warning("Checksum failed: too many files.")
|
||||
return
|
||||
|
||||
with open(data_files[0], "rb") as f:
|
||||
sha1 = hashlib.sha1(f.read()).hexdigest()
|
||||
if sha1 != file_sha1:
|
||||
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
||||
|
||||
|
||||
def get_dataset(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments"
|
||||
) -> "Dataset":
|
||||
max_samples = data_args.max_samples
|
||||
all_datasets: List[Dataset] = [] # support multiple datasets
|
||||
all_datasets: List["Dataset"] = [] # support multiple datasets
|
||||
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
|
||||
if dataset_attr.load_from == "hf_hub":
|
||||
@@ -47,60 +57,56 @@ def get_dataset(
|
||||
data_path = None
|
||||
data_files: List[str] = []
|
||||
|
||||
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory
|
||||
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
|
||||
|
||||
if data_path is None:
|
||||
data_path = ext2type.get(data_files[0].split(".")[-1], None)
|
||||
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
|
||||
else:
|
||||
assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
|
||||
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match."
|
||||
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
|
||||
data_path = ext2type.get(data_files[0].split(".")[-1], None)
|
||||
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError("File not found.")
|
||||
|
||||
assert data_path, "File extension must be txt, csv, json or jsonl."
|
||||
|
||||
if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
|
||||
checksum(data_files[0], dataset_attr.dataset_sha1)
|
||||
else:
|
||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
|
||||
checksum(data_files, dataset_attr.dataset_sha1)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
raw_datasets = load_dataset(
|
||||
dataset = load_dataset(
|
||||
data_path,
|
||||
data_files=data_files,
|
||||
split=data_args.split,
|
||||
cache_dir=model_args.cache_dir,
|
||||
streaming=data_args.streaming,
|
||||
use_auth_token=True if model_args.use_auth_token else None
|
||||
)
|
||||
dataset = raw_datasets[data_args.split]
|
||||
|
||||
if max_samples is not None:
|
||||
max_samples_temp = min(len(dataset), max_samples)
|
||||
dataset = dataset.select(range(max_samples_temp))
|
||||
|
||||
dummy_data = [None] * len(dataset)
|
||||
prefix_data = [dataset_attr.source_prefix] * len(dataset)
|
||||
for column_name, target_name in [
|
||||
("prompt_column", "prompt"),
|
||||
("query_column", "query"),
|
||||
("response_column", "response"),
|
||||
("history_column", "history")
|
||||
]: # every dataset will have 4 columns same as each other
|
||||
if getattr(dataset_attr, column_name) != target_name:
|
||||
if getattr(dataset_attr, column_name):
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
|
||||
else: # None or empty string
|
||||
dataset = dataset.add_column(target_name, dummy_data)
|
||||
dataset = dataset.add_column("prefix", prefix_data)
|
||||
for column_name in ["prompt", "query", "response", "history"]: # align datasets
|
||||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||
|
||||
if dataset_attr.source_prefix: # add prefix
|
||||
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix})
|
||||
|
||||
all_datasets.append(dataset)
|
||||
|
||||
if len(data_args.dataset_list) == 1:
|
||||
all_datasets = all_datasets[0]
|
||||
return all_datasets[0]
|
||||
elif data_args.mix_strategy == "concat":
|
||||
if data_args.streaming:
|
||||
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
||||
return concatenate_datasets(all_datasets)
|
||||
elif data_args.mix_strategy.startswith("interleave"):
|
||||
if not data_args.streaming:
|
||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
||||
return interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)
|
||||
else:
|
||||
all_datasets = concatenate_datasets(all_datasets)
|
||||
|
||||
return all_datasets
|
||||
raise ValueError("Unknown mixing strategy.")
|
||||
|
||||
@@ -1,65 +1,63 @@
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
|
||||
from itertools import chain
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.template import Template
|
||||
from llmtuner.hparams import DataArguments
|
||||
from llmtuner.extras.template import get_template
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
def preprocess_dataset(
|
||||
dataset: Dataset,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
dataset: "Dataset",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||
) -> Dataset:
|
||||
) -> "Dataset":
|
||||
column_names = list(dataset.column_names or [])
|
||||
template = get_template(data_args.template)
|
||||
|
||||
column_names = list(dataset.column_names)
|
||||
prompt_template = Template(data_args.prompt_template)
|
||||
|
||||
# support question with a single answer or multiple answers
|
||||
def get_dialog(examples):
|
||||
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
||||
for i in range(len(examples["prompt"])):
|
||||
if examples["prompt"][i] and examples["response"][i]:
|
||||
query, answer = examples["prompt"][i], examples["response"][i]
|
||||
query = query + "\n" + examples["query"][i] if examples["query"][i] else query
|
||||
prefix = examples["prefix"][i] if examples["prefix"][i] else ""
|
||||
dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
|
||||
yield dialog
|
||||
query, response = examples["prompt"][i], examples["response"][i]
|
||||
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
|
||||
history = history if "history" in examples and examples["history"][i] else []
|
||||
prefix = prefix if "prefix" in examples and examples["prefix"][i] else ""
|
||||
yield query, response, history, prefix
|
||||
|
||||
def preprocess_pretrain_dataset(examples):
|
||||
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
|
||||
text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
|
||||
concatenated_ids = list(chain(*text_ids))
|
||||
total_length = len(concatenated_ids)
|
||||
block_size = data_args.max_source_length - 1
|
||||
tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||
block_size = data_args.max_source_length
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of max_source_length
|
||||
result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
|
||||
for i in range(0, total_length, block_size)]
|
||||
return {
|
||||
"input_ids": result,
|
||||
"labels": result.copy()
|
||||
result = {
|
||||
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
def preprocess_supervised_dataset(examples):
|
||||
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for input with history, we build multiple input-label pairs just like:
|
||||
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
max_length = data_args.max_source_length + data_args.max_target_length
|
||||
|
||||
for dialog in get_dialog(examples):
|
||||
for query, response, history, prefix in construct_example(examples):
|
||||
input_ids, labels = [], []
|
||||
|
||||
for i in range(len(dialog) // 2):
|
||||
source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=(i == 0))
|
||||
target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
|
||||
for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)):
|
||||
source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0))
|
||||
target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
@@ -73,19 +71,20 @@ def preprocess_dataset(
|
||||
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_unsupervised_dataset(examples):
|
||||
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||
# build inputs with format `<bos> X` and labels with format `<bos> Y`
|
||||
model_inputs = {"input_ids": [], "labels": []}
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
for query, response, history, prefix in construct_example(examples):
|
||||
prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
|
||||
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
||||
target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
|
||||
target_ids = tokenizer.encode(text=response, add_special_tokens=True)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
@@ -93,6 +92,7 @@ def preprocess_dataset(
|
||||
target_ids = target_ids[:data_args.max_target_length]
|
||||
|
||||
model_inputs["input_ids"].append(source_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(source_ids))
|
||||
model_inputs["labels"].append(target_ids)
|
||||
|
||||
return model_inputs
|
||||
@@ -100,12 +100,12 @@ def preprocess_dataset(
|
||||
def preprocess_pairwise_dataset(examples):
|
||||
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
|
||||
model_inputs = {"accept_ids": [], "reject_ids": []}
|
||||
for dialog in get_dialog(examples):
|
||||
prompt, answer = "".join(dialog[:-1]), dialog[-1]
|
||||
for query, response, history, prefix in construct_example(examples):
|
||||
prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
|
||||
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
||||
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
|
||||
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
|
||||
accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False)
|
||||
reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
@@ -141,32 +141,44 @@ def preprocess_dataset(
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
||||
if stage == "pt":
|
||||
dataset = dataset.filter(lambda example: example["prompt"])
|
||||
preprocess_function = preprocess_pretrain_dataset
|
||||
elif stage == "sft":
|
||||
preprocess_function = preprocess_unsupervised_dataset \
|
||||
if training_args.predict_with_generate else preprocess_supervised_dataset
|
||||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
|
||||
preprocess_function = preprocess_supervised_dataset
|
||||
elif stage == "rm":
|
||||
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
|
||||
preprocess_function = preprocess_pairwise_dataset
|
||||
elif stage == "ppo":
|
||||
else:
|
||||
dataset = dataset.filter(lambda example: example["prompt"])
|
||||
preprocess_function = preprocess_unsupervised_dataset
|
||||
|
||||
with training_args.main_process_first(desc="dataset map pre-processing"):
|
||||
dataset = dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
desc="Running tokenizer on dataset"
|
||||
)
|
||||
|
||||
dataset = dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
|
||||
|
||||
if stage == "pt":
|
||||
print_unsupervised_dataset_example(dataset[0])
|
||||
print_unsupervised_dataset_example(next(iter(dataset)))
|
||||
elif stage == "sft":
|
||||
print_supervised_dataset_example(dataset[0])
|
||||
print_supervised_dataset_example(next(iter(dataset)))
|
||||
elif stage == "rm":
|
||||
print_pairwise_dataset_example(dataset[0])
|
||||
print_pairwise_dataset_example(next(iter(dataset)))
|
||||
elif stage == "ppo":
|
||||
print_unsupervised_dataset_example(dataset[0])
|
||||
print_unsupervised_dataset_example(next(iter(dataset)))
|
||||
|
||||
return dataset
|
||||
|
||||
15
src/llmtuner/dsets/utils.py
Normal file
15
src/llmtuner/dsets/utils.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
|
||||
|
||||
def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
|
||||
if do_train:
|
||||
if dev_ratio > 1e-6: # Split the dataset
|
||||
dataset = dataset.train_test_split(test_size=dev_ratio)
|
||||
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
else:
|
||||
return {"train_dataset": dataset}
|
||||
else: # do_eval or do_predict
|
||||
return {"eval_dataset": dataset}
|
||||
@@ -1,16 +1,13 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from datetime import timedelta
|
||||
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
TrainerState,
|
||||
TrainingArguments
|
||||
)
|
||||
from transformers.trainer_callback import TrainerControl, TrainerState
|
||||
from transformers.training_args import TrainingArguments
|
||||
from transformers import TrainerCallback
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainingArguments, TrainerState, TrainerControl
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
@@ -20,13 +17,13 @@ class LogCallback(TrainerCallback):
|
||||
self.start_time = time.time()
|
||||
self.tracker = {}
|
||||
|
||||
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of training.
|
||||
"""
|
||||
self.start_time = time.time()
|
||||
|
||||
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of a training step. If using gradient accumulation, one training step
|
||||
might take several inputs.
|
||||
@@ -35,7 +32,7 @@ class LogCallback(TrainerCallback):
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
"""
|
||||
@@ -43,7 +40,7 @@ class LogCallback(TrainerCallback):
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
|
||||
@@ -13,6 +13,12 @@ SUPPORTED_MODELS = {
|
||||
"LLaMA-13B": "huggyllama/llama-13b",
|
||||
"LLaMA-30B": "huggyllama/llama-30b",
|
||||
"LLaMA-65B": "huggyllama/llama-65b",
|
||||
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
|
||||
"LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
|
||||
"LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
|
||||
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
|
||||
"BLOOM-560M": "bigscience/bloom-560m",
|
||||
"BLOOM-3B": "bigscience/bloom-3b",
|
||||
"BLOOM-7B1": "bigscience/bloom-7b1",
|
||||
@@ -30,8 +36,9 @@ SUPPORTED_MODELS = {
|
||||
"InternLM-7B-Chat": "internlm/internlm-chat-7b"
|
||||
}
|
||||
|
||||
DEFAULT_MODULE = { # will be deprecated
|
||||
DEFAULT_MODULE = {
|
||||
"LLaMA": "q_proj,v_proj",
|
||||
"LLaMA2": "q_proj,v_proj",
|
||||
"BLOOM": "query_key_value",
|
||||
"BLOOMZ": "query_key_value",
|
||||
"Falcon": "query_key_value",
|
||||
|
||||
@@ -16,8 +16,16 @@ class LoggerHandler(logging.Handler):
|
||||
self.log += "\n\n"
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
def reset_logging():
|
||||
r"""
|
||||
Removes basic config of root logger
|
||||
"""
|
||||
root = logging.getLogger()
|
||||
list(map(root.removeHandler, root.handlers))
|
||||
list(map(root.removeFilter, root.filters))
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S"
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import torch
|
||||
from typing import List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.generation.utils import LogitsProcessorList
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
r"""
|
||||
@@ -28,7 +30,7 @@ class AverageMeter:
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
# Avoid runtime error in model.generate(do_sample=True).
|
||||
# Avoids runtime error in model.generate(do_sample=True).
|
||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
@@ -44,29 +46,37 @@ def get_logits_processor() -> LogitsProcessorList:
|
||||
return logits_processor
|
||||
|
||||
|
||||
def print_trainable_params(model: torch.nn.Module) -> None:
|
||||
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
r"""
|
||||
Returns the number of trainable parameters and number of all parameters in the model.
|
||||
"""
|
||||
trainable_params, all_param = 0, 0
|
||||
for param in model.parameters():
|
||||
num_params = param.numel()
|
||||
# if using DS Zero 3 and the weights are initialized empty
|
||||
if num_params == 0 and hasattr(param, "ds_numel"):
|
||||
num_params = param.ds_numel
|
||||
|
||||
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
|
||||
if param.__class__.__name__ == "Params4bit":
|
||||
num_params = num_params * 2
|
||||
|
||||
all_param += num_params
|
||||
if param.requires_grad:
|
||||
trainable_params += num_params
|
||||
print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param))
|
||||
|
||||
return trainable_params, all_param
|
||||
|
||||
|
||||
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
|
||||
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
|
||||
def prepare_model_for_training(
|
||||
model: PreTrainedModel,
|
||||
model: "PreTrainedModel",
|
||||
finetuning_type: str,
|
||||
output_embedding_layer_name: Optional[str] = "lm_head",
|
||||
output_layer_name: Optional[str] = "lm_head",
|
||||
use_gradient_checkpointing: Optional[bool] = True,
|
||||
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
|
||||
) -> PreTrainedModel:
|
||||
) -> "PreTrainedModel":
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
@@ -83,19 +93,23 @@ def prepare_model_for_training(
|
||||
model.gradient_checkpointing_enable()
|
||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||
|
||||
if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
|
||||
output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
|
||||
input_dtype = output_embedding_layer.weight.dtype
|
||||
if finetuning_type != "full" and hasattr(model, output_layer_name):
|
||||
if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
|
||||
model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
|
||||
|
||||
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
|
||||
input_dtype = output_layer.weight.dtype
|
||||
|
||||
class CastOutputToFloat(torch.nn.Sequential):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return super().forward(x.to(input_dtype)).to(torch.float32)
|
||||
|
||||
setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
|
||||
setattr(model, output_layer_name, CastOutputToFloat(output_layer))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU memory.
|
||||
|
||||
@@ -12,8 +12,8 @@ from llmtuner.extras.logging import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
|
||||
state_dict = model.state_dict()
|
||||
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
state_dict: Dict[str, torch.Tensor] = model.state_dict()
|
||||
filtered_state_dict = {}
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
|
||||
@@ -3,66 +3,61 @@ from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Format:
|
||||
class Template:
|
||||
|
||||
prefix: str
|
||||
prompt: str
|
||||
sep: str
|
||||
use_history: bool
|
||||
|
||||
|
||||
templates: Dict[str, Format] = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
|
||||
name: str
|
||||
|
||||
def __post_init__(self):
|
||||
if self.name in templates:
|
||||
self.prefix = templates[self.name].prefix
|
||||
self.prompt = templates[self.name].prompt
|
||||
self.sep = templates[self.name].sep
|
||||
self.use_history = templates[self.name].use_history
|
||||
else:
|
||||
raise ValueError("Template {} does not exist.".format(self.name))
|
||||
|
||||
def get_prompt(
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = "",
|
||||
eos_token: Optional[str] = "</s>"
|
||||
) -> str:
|
||||
r"""
|
||||
Returns a string containing prompt without response.
|
||||
"""
|
||||
return "".join(self._format_example(query, history, prefix))
|
||||
return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix)))
|
||||
|
||||
def get_dialog(
|
||||
self, query: str, resp: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
|
||||
) -> List[str]:
|
||||
self,
|
||||
query: str,
|
||||
resp: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = ""
|
||||
) -> List[Tuple[str, str]]:
|
||||
r"""
|
||||
Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
|
||||
Returns a list containing prompt-response pairs.
|
||||
"""
|
||||
return self._format_example(query, history, prefix) + [resp]
|
||||
result = self._format_example(query, history, prefix)
|
||||
result[-1][-1] = resp
|
||||
return result
|
||||
|
||||
def _format_example(
|
||||
self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = ""
|
||||
) -> List[str]:
|
||||
prefix = prefix if prefix else self.prefix # use prefix if provided
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = ""
|
||||
) -> List[Tuple[str, str]]:
|
||||
prefix = prefix or self.prefix # use prefix if provided
|
||||
prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
|
||||
history = history if (history and self.use_history) else []
|
||||
history = history + [(query, "<dummy>")]
|
||||
convs = []
|
||||
for turn_idx, (user_query, bot_resp) in enumerate(history):
|
||||
if turn_idx == 0:
|
||||
convs.append(prefix + self.prompt.format(query=user_query))
|
||||
convs.append(bot_resp)
|
||||
else:
|
||||
convs.append(self.sep + self.prompt.format(query=user_query))
|
||||
convs.append(bot_resp)
|
||||
return convs[:-1] # drop last
|
||||
history = history + [(query, "")]
|
||||
convs = [
|
||||
[(self.sep if turn_idx else prefix) + self.prompt.format(query=query_i), resp_i]
|
||||
for turn_idx, (query_i, resp_i) in enumerate(history)
|
||||
]
|
||||
return convs
|
||||
|
||||
|
||||
templates: Dict[str, Template] = {}
|
||||
|
||||
|
||||
def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
|
||||
templates[name] = Format(
|
||||
templates[name] = Template(
|
||||
prefix=prefix,
|
||||
prompt=prompt,
|
||||
sep=sep,
|
||||
@@ -70,6 +65,12 @@ def register_template(name: str, prefix: str, prompt: str, sep: str, use_history
|
||||
)
|
||||
|
||||
|
||||
def get_template(name: str) -> Template:
|
||||
template = templates.get(name, None)
|
||||
assert template is not None, "Template {} does not exist.".format(name)
|
||||
return template
|
||||
|
||||
|
||||
r"""
|
||||
Supports language model inference without histories.
|
||||
"""
|
||||
@@ -95,6 +96,27 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
|
||||
https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
|
||||
https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
|
||||
"""
|
||||
register_template(
|
||||
name="llama2",
|
||||
prefix="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
|
||||
prompt=" [INST] {query} [/INST] ",
|
||||
sep="",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
|
||||
https://github.com/ymcui/Chinese-LLaMA-Alpaca
|
||||
@@ -118,7 +140,7 @@ register_template(
|
||||
prefix="A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||
prompt="USER: {query} ASSISTANT: ",
|
||||
sep="</s>",
|
||||
sep="",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
@@ -202,8 +224,8 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
|
||||
register_template(
|
||||
name="baichuan",
|
||||
prefix="",
|
||||
prompt=" <reserved_102> {query} <reserved_103> ",
|
||||
sep="</s>",
|
||||
prompt="<reserved_102>{query}<reserved_103>",
|
||||
sep="",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
from typing import List, Optional
|
||||
from typing import List, Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@@ -16,10 +16,10 @@ class DatasetAttr:
|
||||
return self.dataset_name
|
||||
|
||||
def __post_init__(self):
|
||||
self.prompt_column = "instruction"
|
||||
self.query_column = "input"
|
||||
self.response_column = "output"
|
||||
self.history_column = None
|
||||
self.prompt = "instruction"
|
||||
self.query = "input"
|
||||
self.response = "output"
|
||||
self.history = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -27,8 +27,11 @@ class DataArguments:
|
||||
"""
|
||||
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
||||
"""
|
||||
template: str = field(
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||
)
|
||||
dataset: Optional[str] = field(
|
||||
default="alpaca_zh",
|
||||
default="alpaca_en",
|
||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
|
||||
)
|
||||
dataset_dir: Optional[str] = field(
|
||||
@@ -39,6 +42,18 @@ class DataArguments:
|
||||
default="train",
|
||||
metadata={"help": "Which dataset split to use for training and evaluation."}
|
||||
)
|
||||
streaming: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable streaming mode."}
|
||||
)
|
||||
buffer_size: Optional[int] = field(
|
||||
default=16384,
|
||||
metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
|
||||
)
|
||||
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
|
||||
default="concat",
|
||||
metadata={"help": "Strategy to use in dataset mixing."}
|
||||
)
|
||||
overwrite_cache: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Overwrite the cached training and evaluation sets."}
|
||||
@@ -75,10 +90,6 @@ class DataArguments:
|
||||
default=0,
|
||||
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
|
||||
)
|
||||
prompt_template: Optional[str] = field(
|
||||
default="default",
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||
)
|
||||
|
||||
def init_for_training(self): # support mixing multiple datasets
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
|
||||
@@ -111,9 +122,9 @@ class DataArguments:
|
||||
dataset_attr.source_prefix = prefix_list[i]
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
|
||||
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
|
||||
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
|
||||
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
|
||||
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
|
||||
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
|
||||
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
|
||||
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
|
||||
|
||||
self.dataset_list.append(dataset_attr)
|
||||
@@ -16,9 +16,10 @@ class FinetuningArguments:
|
||||
default=32,
|
||||
metadata={"help": "Number of decoder blocks in the model. \
|
||||
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
|
||||
LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
|
||||
BLOOM choices: [\"24\", \"30\", \"70\"], \
|
||||
Falcon choices: [\"32\", \"60\"], \
|
||||
Baichuan choices: [\"32\"]"}
|
||||
Baichuan choices: [\"32\", \"40\"]"}
|
||||
)
|
||||
num_layer_trainable: Optional[int] = field(
|
||||
default=3,
|
||||
@@ -27,7 +28,7 @@ class FinetuningArguments:
|
||||
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
|
||||
default="mlp",
|
||||
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
|
||||
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
||||
LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
|
||||
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
|
||||
Baichuan choices: [\"mlp\", \"self_attn\"]"}
|
||||
)
|
||||
@@ -46,7 +47,7 @@ class FinetuningArguments:
|
||||
lora_target: Optional[str] = field(
|
||||
default="q_proj,v_proj",
|
||||
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
|
||||
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
|
||||
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
|
||||
@dataclass
|
||||
class GeneralArguments:
|
||||
"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
Arguments pertaining to which stage we are going to perform.
|
||||
"""
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
|
||||
default="sft",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from peft import (
|
||||
PeftModel,
|
||||
TaskType,
|
||||
@@ -12,19 +12,22 @@ from peft.utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.save_and_load import load_trainable_params
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def init_adapter(
|
||||
model: PreTrainedModel,
|
||||
model_args: ModelArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
model: "PreTrainedModel",
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: bool,
|
||||
is_mergeable: bool
|
||||
) -> PreTrainedModel:
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Initializes the adapters.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Literal, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Tuple
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
@@ -10,30 +10,34 @@ from transformers import (
|
||||
)
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.utils.versions import require_version
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizerBase
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params
|
||||
from llmtuner.extras.logging import reset_logging, get_logger
|
||||
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
|
||||
from llmtuner.extras.save_and_load import load_valuehead_params
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.tuner.core.adapter import init_adapter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
check_min_version("4.29.1")
|
||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
|
||||
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
|
||||
require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
model_args: ModelArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: Optional[bool] = False,
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
|
||||
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
|
||||
@@ -80,9 +84,6 @@ def load_model_and_tokenizer(
|
||||
|
||||
elif model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
|
||||
require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
|
||||
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
|
||||
config_kwargs["load_in_4bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
@@ -92,11 +93,13 @@ def load_model_and_tokenizer(
|
||||
)
|
||||
|
||||
is_mergeable = False
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
if not is_trainable: # `device_map=auto` should be used for inference only
|
||||
config_kwargs["device_map"] = "auto"
|
||||
if (
|
||||
model_args.quantization_bit is not None
|
||||
or (os.environ.get('LOCAL_RANK') is not None and not is_deepspeed_zero3_enabled())
|
||||
):
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
||||
|
||||
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
|
||||
model_to_load = model_args.checkpoint_dir[0]
|
||||
@@ -108,7 +111,7 @@ def load_model_and_tokenizer(
|
||||
model_to_load,
|
||||
config=config,
|
||||
torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
||||
**config_kwargs
|
||||
)
|
||||
|
||||
@@ -126,6 +129,7 @@ def load_model_and_tokenizer(
|
||||
|
||||
if stage == "rm" or stage == "ppo": # add value head
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
reset_logging()
|
||||
|
||||
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
|
||||
@@ -146,6 +150,9 @@ def load_model_and_tokenizer(
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
|
||||
|
||||
print_trainable_params(model)
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param
|
||||
))
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
@@ -19,20 +19,39 @@ from llmtuner.hparams import (
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
return parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
return parser.parse_args_into_dataclasses()
|
||||
|
||||
|
||||
def parse_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
|
||||
))
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def parse_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
))
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))
|
||||
|
||||
if args is not None:
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
@@ -54,12 +73,21 @@ def get_train_args(
|
||||
assert not (training_args.do_train and training_args.predict_with_generate), \
|
||||
"`predict_with_generate` cannot be set as True while training."
|
||||
|
||||
assert (not training_args.do_predict) or training_args.predict_with_generate, \
|
||||
assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
|
||||
"Please enable `predict_with_generate` to save model predictions."
|
||||
|
||||
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
|
||||
"Quantization is only compatible with the LoRA method."
|
||||
|
||||
assert not (training_args.max_steps == -1 and data_args.streaming), \
|
||||
"Please specify `max_steps` in streaming mode."
|
||||
|
||||
assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
|
||||
"Streaming mode does not support evaluation currently."
|
||||
|
||||
assert not (general_args.stage == "ppo" and data_args.streaming), \
|
||||
"Streaming mode does not suppport PPO training currently."
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
|
||||
@@ -73,13 +101,22 @@ def get_train_args(
|
||||
if training_args.do_train and (not training_args.fp16):
|
||||
logger.warning("We recommend enable fp16 mixed precision training.")
|
||||
|
||||
if data_args.prompt_template == "default":
|
||||
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
|
||||
|
||||
if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
|
||||
if (
|
||||
training_args.local_rank != -1
|
||||
and training_args.ddp_find_unused_parameters is None
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
):
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
|
||||
training_args.ddp_find_unused_parameters = False
|
||||
|
||||
if data_args.max_samples is not None and data_args.streaming:
|
||||
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
|
||||
data_args.max_samples = None
|
||||
|
||||
if data_args.dev_ratio > 1e-6 and data_args.streaming:
|
||||
logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
|
||||
data_args.dev_ratio = 0
|
||||
|
||||
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
@@ -91,10 +128,10 @@ def get_train_args(
|
||||
model_args.compute_dtype = torch.float32
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
|
||||
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format(
|
||||
training_args.local_rank, training_args.device, training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1), training_args.fp16
|
||||
))
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Set seed before initializing model.
|
||||
@@ -106,17 +143,7 @@ def get_train_args(
|
||||
def get_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
|
||||
|
||||
parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))
|
||||
|
||||
if args is not None:
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
|
||||
else:
|
||||
model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
|
||||
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
|
||||
|
||||
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
|
||||
"Quantization is only compatible with the LoRA method."
|
||||
@@ -128,7 +155,4 @@ def get_infer_args(
|
||||
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
|
||||
"Quantized model only accepts a single checkpoint."
|
||||
|
||||
if data_args.prompt_template == "default":
|
||||
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Dict, Optional
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
from transformers import Seq2SeqTrainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.modeling_utils import unwrap_model
|
||||
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
|
||||
from transformers.modeling_utils import PreTrainedModel, unwrap_model
|
||||
from peft import PeftModel
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -20,7 +24,7 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
|
||||
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
self._remove_log()
|
||||
@@ -41,29 +45,34 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
logger.info(f"Saving model checkpoint to {output_dir}")
|
||||
|
||||
model = unwrap_model(self.model)
|
||||
|
||||
if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
|
||||
backbone_model = getattr(model, "pretrained_model")
|
||||
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
else:
|
||||
backbone_model = model
|
||||
if isinstance(model, PreTrainedModelWrapper):
|
||||
# Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
|
||||
model_state_dict = state_dict or model.state_dict()
|
||||
v_head_state_dict = {
|
||||
name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
|
||||
for name in model_state_dict.keys() if name.startswith("v_head.")
|
||||
}
|
||||
|
||||
if self.finetuning_args.finetuning_type == "lora":
|
||||
backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
|
||||
else: # freeze/full tuning
|
||||
backbone_model.config.use_cache = True
|
||||
backbone_model.save_pretrained(
|
||||
output_dir,
|
||||
state_dict=get_state_dict(backbone_model),
|
||||
safe_serialization=self.args.save_safetensors
|
||||
)
|
||||
backbone_model.config.use_cache = False
|
||||
if self.tokenizer is not None:
|
||||
torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
model = model.pretrained_model
|
||||
|
||||
state_dict = state_dict or get_state_dict(model)
|
||||
if isinstance(model, (PeftModel, PreTrainedModel)):
|
||||
model.config.use_cache = True
|
||||
model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
|
||||
model.config.use_cache = False
|
||||
else:
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
|
||||
if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
|
||||
f.write(self.args.to_json_string() + "\n")
|
||||
|
||||
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
|
||||
|
||||
def _load_best_model(self):
|
||||
@@ -73,16 +82,15 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
|
||||
"""
|
||||
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
|
||||
|
||||
model = unwrap_model(self.model)
|
||||
backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
|
||||
|
||||
if self.finetuning_args.finetuning_type == "lora":
|
||||
backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
|
||||
if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "reward_head_weight"),
|
||||
"summary.bias": getattr(model, "reward_head_bias")
|
||||
})
|
||||
if isinstance(model, PreTrainedModelWrapper):
|
||||
model.v_head.load_state_dict(torch.load(
|
||||
os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
|
||||
))
|
||||
model = model.pretrained_model
|
||||
|
||||
if isinstance(model, PeftModel):
|
||||
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
|
||||
else: # freeze/full-tuning
|
||||
load_trainable_params(backbone_model, self.state.best_model_checkpoint)
|
||||
load_trainable_params(model, self.state.best_model_checkpoint)
|
||||
|
||||
@@ -2,21 +2,25 @@ import os
|
||||
import math
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl
|
||||
from transformers import TrainerState, TrainerControl
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from trl import PPOTrainer
|
||||
from trl.core import LengthSampler
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, get_logits_processor
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -25,11 +29,12 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
r"""
|
||||
Inherits PPOTrainer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: List[LogCallback],
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: List["LogCallback"],
|
||||
**kwargs
|
||||
):
|
||||
PPOTrainer.__init__(self, **kwargs)
|
||||
@@ -66,7 +71,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {max_steps}")
|
||||
logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
|
||||
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
gen_kwargs = {
|
||||
@@ -107,7 +112,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
# Compute rewards
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
with torch.no_grad():
|
||||
_, _, values = self.model(**self.prepare_model_inputs(queries, responses))
|
||||
_, _, values = self.model(
|
||||
**self.prepare_model_inputs(queries, responses),
|
||||
output_hidden_states=True,
|
||||
return_dict=True
|
||||
)
|
||||
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import torch
|
||||
from typing import Dict, List, Literal, Optional, Tuple
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
|
||||
|
||||
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
||||
if target == "reward": # save default head temporarily
|
||||
valuehead_state_dict = model.v_head.state_dict()
|
||||
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
|
||||
@@ -19,10 +21,10 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
|
||||
|
||||
|
||||
def cast_layernorm_dtype(
|
||||
model: AutoModelForCausalLMWithValueHead,
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
layer_norm_names: List[str] = LAYERNORM_NAMES,
|
||||
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
|
||||
) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
|
||||
|
||||
layer_norm_state_dict = {}
|
||||
|
||||
|
||||
@@ -2,26 +2,30 @@
|
||||
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
from trl import PPOConfig
|
||||
from torch.optim import AdamW
|
||||
from typing import Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, TrainerCallback
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from transformers.optimization import get_scheduler
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
def run_ppo(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
|
||||
|
||||
@@ -1,24 +1,27 @@
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
|
||||
|
||||
import math
|
||||
from typing import Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
def run_pt(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
|
||||
@@ -28,16 +31,6 @@ def run_pt(
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
# Split the dataset
|
||||
if training_args.do_train:
|
||||
if data_args.dev_ratio > 1e-6:
|
||||
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
|
||||
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
else:
|
||||
trainer_kwargs = {"train_dataset": dataset}
|
||||
else: # do_eval or do_predict
|
||||
trainer_kwargs = {"eval_dataset": dataset}
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = PeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
@@ -46,7 +39,7 @@ def run_pt(
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**trainer_kwargs
|
||||
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
|
||||
)
|
||||
|
||||
# Training
|
||||
|
||||
@@ -15,5 +15,8 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
|
||||
features = [
|
||||
{"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
|
||||
for key in ("accept_ids", "reject_ids") for feature in features
|
||||
]
|
||||
return super().__call__(features)
|
||||
|
||||
@@ -1,9 +1,18 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.trainer import PredictionOutput
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PairwisePeftTrainer(PeftTrainer):
|
||||
r"""
|
||||
@@ -16,7 +25,7 @@ class PairwisePeftTrainer(PeftTrainer):
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: PreTrainedModel,
|
||||
model: "PreTrainedModel",
|
||||
inputs: Dict[str, torch.Tensor],
|
||||
return_outputs: Optional[bool] = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
@@ -32,7 +41,30 @@ class PairwisePeftTrainer(PeftTrainer):
|
||||
See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
|
||||
"""
|
||||
batch_size = inputs["input_ids"].size(0) // 2
|
||||
_, _, values = model(**inputs)
|
||||
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
||||
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
|
||||
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
|
||||
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
predict_results: "PredictionOutput"
|
||||
) -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
||||
|
||||
A custom behavior that not contained in Seq2SeqTrainer.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
|
||||
logger.info(f"Saving prediction results to {output_prediction_file}")
|
||||
|
||||
acc_scores, rej_scores = predict_results.predictions
|
||||
|
||||
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||
res: List[str] = []
|
||||
for acc_score, rej_score in zip(acc_scores, rej_scores):
|
||||
res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)}))
|
||||
writer.write("\n".join(res))
|
||||
|
||||
@@ -2,25 +2,27 @@
|
||||
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
|
||||
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||
|
||||
from typing import Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.rm.metric import compute_accuracy
|
||||
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
|
||||
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
def run_rm(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
|
||||
@@ -29,16 +31,6 @@ def run_rm(
|
||||
|
||||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
|
||||
# Split the dataset
|
||||
if training_args.do_train:
|
||||
if data_args.dev_ratio > 1e-6:
|
||||
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
|
||||
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
else:
|
||||
trainer_kwargs = {"train_dataset": dataset}
|
||||
else: # do_eval or do_predict
|
||||
trainer_kwargs = {"eval_dataset": dataset}
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = PairwisePeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
@@ -48,7 +40,7 @@ def run_rm(
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
compute_metrics=compute_accuracy,
|
||||
**trainer_kwargs
|
||||
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
|
||||
)
|
||||
|
||||
# Training
|
||||
@@ -66,3 +58,10 @@ def run_rm(
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
predict_results = trainer.predict(dataset, metric_key_prefix="predict")
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(predict_results)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Sequence, Tuple, Union
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
|
||||
|
||||
import jieba
|
||||
from rouge_chinese import Rouge
|
||||
@@ -9,6 +8,9 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComputeMetrics:
|
||||
@@ -16,7 +18,7 @@ class ComputeMetrics:
|
||||
Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
|
||||
"""
|
||||
|
||||
tokenizer: PreTrainedTokenizer
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
|
||||
def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
||||
r"""
|
||||
|
||||
@@ -3,13 +3,15 @@ import json
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from transformers.trainer import PredictionOutput
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.trainer import PredictionOutput
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -36,31 +38,44 @@ class Seq2SeqPeftTrainer(PeftTrainer):
|
||||
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
|
||||
if label_len > prompt_len:
|
||||
inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
|
||||
if "attention_mask" in inputs:
|
||||
inputs["attention_mask"] = self._pad_tensors_to_target_len(
|
||||
inputs["attention_mask"], inputs["labels"], pad_token_id=0
|
||||
)
|
||||
if "position_ids" in inputs:
|
||||
inputs["position_ids"] = self._pad_tensors_to_target_len(
|
||||
inputs["position_ids"], inputs["labels"], pad_token_id=0
|
||||
)
|
||||
|
||||
loss, generated_tokens, labels = super().prediction_step(
|
||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
||||
)
|
||||
generated_tokens = generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
|
||||
generated_tokens = (
|
||||
generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
|
||||
)
|
||||
|
||||
return (loss, generated_tokens, labels)
|
||||
|
||||
def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
|
||||
def _pad_tensors_to_target_len(
|
||||
self,
|
||||
src_tensor: torch.Tensor,
|
||||
tgt_tensor: torch.Tensor,
|
||||
pad_token_id: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Pads the tensor to the same length as the target tensor.
|
||||
|
||||
Should only be called when predict_with_generate=True.
|
||||
"""
|
||||
if pad_token_id is None:
|
||||
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
|
||||
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
|
||||
# If PAD token is not defined at least EOS token has to be defined
|
||||
pad_token_id = (
|
||||
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
|
||||
)
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
else:
|
||||
if self.model.config.pad_token_id is not None:
|
||||
pad_token_id = self.model.config.pad_token_id
|
||||
else:
|
||||
raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
|
||||
raise ValueError("Pad_token_id must be set in the configuration of the model.")
|
||||
|
||||
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
|
||||
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
|
||||
@@ -68,7 +83,7 @@ class Seq2SeqPeftTrainer(PeftTrainer):
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
predict_results: PredictionOutput
|
||||
predict_results: "PredictionOutput"
|
||||
) -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
||||
|
||||
@@ -1,25 +1,28 @@
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
|
||||
|
||||
from typing import Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, TrainerCallback
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.sft.metric import ComputeMetrics
|
||||
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
def run_sft(
|
||||
model_args: ModelArguments,
|
||||
data_args: DataArguments,
|
||||
training_args: Seq2SeqTrainingArguments,
|
||||
finetuning_args: FinetuningArguments,
|
||||
callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
@@ -35,16 +38,6 @@ def run_sft(
|
||||
training_args.generation_num_beams = data_args.eval_num_beams if \
|
||||
data_args.eval_num_beams is not None else training_args.generation_num_beams
|
||||
|
||||
# Split the dataset
|
||||
if training_args.do_train:
|
||||
if data_args.dev_ratio > 1e-6:
|
||||
dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
|
||||
trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
else:
|
||||
trainer_kwargs = {"train_dataset": dataset}
|
||||
else: # do_eval or do_predict
|
||||
trainer_kwargs = {"eval_dataset": dataset}
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Seq2SeqPeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
@@ -54,7 +47,7 @@ def run_sft(
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
|
||||
**trainer_kwargs
|
||||
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
|
||||
)
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from llmtuner.webui.interface import create_ui
|
||||
|
||||
@@ -54,7 +54,7 @@ class WebChatModel(ChatModel):
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||
prompt_template=template,
|
||||
template=template,
|
||||
source_prefix=source_prefix
|
||||
)
|
||||
super().__init__(*get_infer_args(args))
|
||||
@@ -73,6 +73,7 @@ class WebChatModel(ChatModel):
|
||||
chatbot: List[Tuple[str, str]],
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
prefix: str,
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float
|
||||
@@ -80,9 +81,17 @@ class WebChatModel(ChatModel):
|
||||
chatbot.append([query, ""])
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
query, history, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
):
|
||||
response += new_text
|
||||
response = self.postprocess(response)
|
||||
new_history = history + [(query, response)]
|
||||
chatbot[-1] = [query, response]
|
||||
yield chatbot, new_history
|
||||
|
||||
def postprocess(self, response: str) -> str:
|
||||
blocks = response.split("```")
|
||||
for i, block in enumerate(blocks):
|
||||
if i % 2 == 0:
|
||||
blocks[i] = block.replace("<", "<").replace(">", ">")
|
||||
return "```".join(blocks)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from llmtuner.webui.components.eval import create_eval_tab
|
||||
from llmtuner.webui.components.infer import create_infer_tab
|
||||
from llmtuner.webui.components.top import create_top
|
||||
from llmtuner.webui.components.sft import create_sft_tab
|
||||
from llmtuner.webui.components.eval import create_eval_tab
|
||||
from llmtuner.webui.components.infer import create_infer_tab
|
||||
from llmtuner.webui.components.export import create_export_tab
|
||||
|
||||
@@ -1,42 +1,37 @@
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||
|
||||
import gradio as gr
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
|
||||
|
||||
def create_chat_box(
|
||||
chat_model: WebChatModel,
|
||||
chat_model: "WebChatModel",
|
||||
visible: Optional[bool] = False
|
||||
) -> Tuple[Block, Component, Component, Dict[str, Component]]:
|
||||
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Box(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot()
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
with gr.Column(scale=12):
|
||||
prefix = gr.Textbox(show_label=False)
|
||||
query = gr.Textbox(show_label=False, lines=8)
|
||||
|
||||
with gr.Column(min_width=32, scale=1):
|
||||
submit_btn = gr.Button(variant="primary")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
clear_btn = gr.Button()
|
||||
max_new_tokens = gr.Slider(
|
||||
10, 2048, value=chat_model.generating_args.max_new_tokens, step=1, interactive=True
|
||||
)
|
||||
top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01, interactive=True)
|
||||
temperature = gr.Slider(
|
||||
0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01, interactive=True
|
||||
)
|
||||
max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1)
|
||||
top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)
|
||||
|
||||
history = gr.State([])
|
||||
|
||||
submit_btn.click(
|
||||
chat_model.predict,
|
||||
[chatbot, query, history, max_new_tokens, top_p, temperature],
|
||||
[chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
|
||||
[chatbot, history],
|
||||
show_progress=True
|
||||
).then(
|
||||
@@ -46,6 +41,7 @@ def create_chat_box(
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
|
||||
|
||||
return chat_box, chatbot, history, dict(
|
||||
prefix=prefix,
|
||||
query=query,
|
||||
submit_btn=submit_btn,
|
||||
clear_btn=clear_btn,
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import gradio as gr
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
from typing import Tuple
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
def create_preview_box() -> Tuple[Block, Component, Component, Component]:
|
||||
def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
|
||||
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
|
||||
with gr.Row():
|
||||
preview_count = gr.Number(interactive=False)
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from typing import Dict
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.components.data import create_preview_box
|
||||
from llmtuner.webui.runner import Runner
|
||||
from llmtuner.webui.utils import can_preview, get_preview
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.runner import Runner
|
||||
|
||||
def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
|
||||
|
||||
def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
@@ -31,6 +33,7 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
|
||||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
|
||||
with gr.Box():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
start_btn.click(
|
||||
|
||||
36
src/llmtuner/webui/components/export.py
Normal file
36
src/llmtuner/webui/components/export.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
import gradio as gr
|
||||
|
||||
from llmtuner.webui.utils import export_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
save_dir = gr.Textbox()
|
||||
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
|
||||
|
||||
export_btn = gr.Button()
|
||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||
|
||||
export_btn.click(
|
||||
export_model,
|
||||
[
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
max_shard_size,
|
||||
save_dir
|
||||
],
|
||||
[info_box]
|
||||
)
|
||||
|
||||
return dict(
|
||||
save_dir=save_dir,
|
||||
max_shard_size=max_shard_size,
|
||||
export_btn=export_btn,
|
||||
info_box=info_box
|
||||
)
|
||||
@@ -1,18 +1,20 @@
|
||||
from typing import Dict
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
from llmtuner.webui.components.chatbot import create_chat_box
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
|
||||
|
||||
def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
load_btn = gr.Button()
|
||||
unload_btn = gr.Button()
|
||||
|
||||
info_box = gr.Markdown()
|
||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||
|
||||
chat_model = WebChatModel()
|
||||
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from typing import Dict
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from transformers.trainer_utils import SchedulerType
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.components.data import create_preview_box
|
||||
from llmtuner.webui.runner import Runner
|
||||
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.runner import Runner
|
||||
|
||||
def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
|
||||
|
||||
def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
@@ -35,20 +37,31 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
|
||||
lr_scheduler_type = gr.Dropdown(
|
||||
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
|
||||
)
|
||||
max_grad_norm = gr.Textbox(value="1.0")
|
||||
dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
|
||||
fp16 = gr.Checkbox(value=True)
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
|
||||
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
|
||||
|
||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
|
||||
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
lora_target = gr.Textbox(scale=2)
|
||||
|
||||
with gr.Row():
|
||||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
output_dir = gr.Textbox(interactive=True)
|
||||
with gr.Column(scale=3):
|
||||
output_dir = gr.Textbox()
|
||||
|
||||
with gr.Box():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
with gr.Column(scale=1):
|
||||
@@ -74,10 +87,15 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
|
||||
batch_size,
|
||||
gradient_accumulation_steps,
|
||||
lr_scheduler_type,
|
||||
max_grad_norm,
|
||||
dev_ratio,
|
||||
fp16,
|
||||
logging_steps,
|
||||
save_steps,
|
||||
warmup_steps,
|
||||
compute_type,
|
||||
lora_rank,
|
||||
lora_dropout,
|
||||
lora_target,
|
||||
output_dir
|
||||
],
|
||||
[output_box]
|
||||
@@ -103,10 +121,17 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
|
||||
batch_size=batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
max_grad_norm=max_grad_norm,
|
||||
dev_ratio=dev_ratio,
|
||||
fp16=fp16,
|
||||
advanced_tab=advanced_tab,
|
||||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
warmup_steps=warmup_steps,
|
||||
compute_type=compute_type,
|
||||
lora_tab=lora_tab,
|
||||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
output_dir=output_dir,
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
from typing import Dict
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
|
||||
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
|
||||
from llmtuner.extras.template import templates
|
||||
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
|
||||
from llmtuner.webui.utils import can_quantize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
def create_top() -> Dict[str, Component]:
|
||||
|
||||
def create_top() -> Dict[str, "Component"]:
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
|
||||
with gr.Row():
|
||||
@@ -22,10 +24,11 @@ def create_top() -> Dict[str, Component]:
|
||||
checkpoints = gr.Dropdown(multiselect=True, scale=5)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown([8, 4], scale=1)
|
||||
template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=2)
|
||||
source_prefix = gr.Textbox(scale=4)
|
||||
template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
|
||||
source_prefix = gr.Textbox(scale=2)
|
||||
|
||||
model_name.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
@@ -47,9 +50,10 @@ def create_top() -> Dict[str, Component]:
|
||||
model_name=model_name,
|
||||
model_path=model_path,
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
checkpoints=checkpoints,
|
||||
refresh_btn=refresh_btn,
|
||||
advanced_tab=advanced_tab,
|
||||
quantization_bit=quantization_bit,
|
||||
template=template,
|
||||
source_prefix=source_prefix
|
||||
)
|
||||
|
||||
@@ -5,7 +5,8 @@ from llmtuner.webui.components import (
|
||||
create_top,
|
||||
create_sft_tab,
|
||||
create_eval_tab,
|
||||
create_infer_tab
|
||||
create_infer_tab,
|
||||
create_export_tab
|
||||
)
|
||||
from llmtuner.webui.css import CSS
|
||||
from llmtuner.webui.manager import Manager
|
||||
@@ -27,10 +28,13 @@ def create_ui() -> gr.Blocks:
|
||||
with gr.Tab("Evaluate"):
|
||||
eval_elems = create_eval_tab(top_elems, runner)
|
||||
|
||||
with gr.Tab("Inference"):
|
||||
with gr.Tab("Chat"):
|
||||
infer_elems = create_infer_tab(top_elems)
|
||||
|
||||
elem_list = [top_elems, sft_elems, eval_elems, infer_elems]
|
||||
with gr.Tab("Export"):
|
||||
export_elems = create_export_tab(top_elems)
|
||||
|
||||
elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
|
||||
manager = Manager(elem_list)
|
||||
|
||||
demo.load(
|
||||
@@ -51,4 +55,4 @@ def create_ui() -> gr.Blocks:
|
||||
if __name__ == "__main__":
|
||||
demo = create_ui()
|
||||
demo.queue()
|
||||
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
|
||||
|
||||
@@ -49,6 +49,14 @@ LOCALES = {
|
||||
"value": "刷新断点"
|
||||
}
|
||||
},
|
||||
"advanced_tab": {
|
||||
"en": {
|
||||
"label": "Advanced configurations"
|
||||
},
|
||||
"zh": {
|
||||
"label": "高级设置"
|
||||
}
|
||||
},
|
||||
"quantization_bit": {
|
||||
"en": {
|
||||
"label": "Quantization bit (optional)",
|
||||
@@ -71,12 +79,12 @@ LOCALES = {
|
||||
},
|
||||
"source_prefix": {
|
||||
"en": {
|
||||
"label": "Source prefix (optional)",
|
||||
"info": "A sequence used as the prefix of each samples."
|
||||
"label": "System prompt (optional)",
|
||||
"info": "A sequence used as the default system prompt."
|
||||
},
|
||||
"zh": {
|
||||
"label": "前缀序列(非必填)",
|
||||
"info": "作为每个输入样本前缀的序列"
|
||||
"label": "系统提示词(非必填)",
|
||||
"info": "默认使用的系统提示词"
|
||||
}
|
||||
},
|
||||
"dataset_dir": {
|
||||
@@ -209,6 +217,16 @@ LOCALES = {
|
||||
"info": "采用的学习率调节器名称。"
|
||||
}
|
||||
},
|
||||
"max_grad_norm": {
|
||||
"en": {
|
||||
"label": "Maximum gradient norm",
|
||||
"info": "Norm for gradient clipping.."
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大梯度范数",
|
||||
"info": "用于梯度裁剪的范数。"
|
||||
}
|
||||
},
|
||||
"dev_ratio": {
|
||||
"en": {
|
||||
"label": "Dev ratio",
|
||||
@@ -219,20 +237,10 @@ LOCALES = {
|
||||
"info": "验证集占全部样本的百分比。"
|
||||
}
|
||||
},
|
||||
"fp16": {
|
||||
"en": {
|
||||
"label": "fp16",
|
||||
"info": "Whether to use fp16 mixed precision training."
|
||||
},
|
||||
"zh": {
|
||||
"label": "fp16",
|
||||
"info": "是否启用 FP16 混合精度训练。"
|
||||
}
|
||||
},
|
||||
"logging_steps": {
|
||||
"en": {
|
||||
"label": "Logging steps",
|
||||
"info": "Number of update steps between two logs."
|
||||
"info": "Number of steps between two logs."
|
||||
},
|
||||
"zh": {
|
||||
"label": "日志间隔",
|
||||
@@ -242,13 +250,71 @@ LOCALES = {
|
||||
"save_steps": {
|
||||
"en": {
|
||||
"label": "Save steps",
|
||||
"info": "Number of updates steps between two checkpoints."
|
||||
"info": "Number of steps between two checkpoints."
|
||||
},
|
||||
"zh": {
|
||||
"label": "保存间隔",
|
||||
"info": "每两次断点保存间的更新步数。"
|
||||
}
|
||||
},
|
||||
"warmup_steps": {
|
||||
"en": {
|
||||
"label": "Warmup steps",
|
||||
"info": "Number of steps used for warmup."
|
||||
},
|
||||
"zh": {
|
||||
"label": "预热步数",
|
||||
"info": "学习率预热采用的步数。"
|
||||
}
|
||||
},
|
||||
"compute_type": {
|
||||
"en": {
|
||||
"label": "Compute type",
|
||||
"info": "Whether to use fp16 or bf16 mixed precision training."
|
||||
},
|
||||
"zh": {
|
||||
"label": "计算类型",
|
||||
"info": "是否启用 FP16 或 BF16 混合精度训练。"
|
||||
}
|
||||
},
|
||||
"lora_tab": {
|
||||
"en": {
|
||||
"label": "LoRA configurations"
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 参数设置"
|
||||
}
|
||||
},
|
||||
"lora_rank": {
|
||||
"en": {
|
||||
"label": "LoRA rank",
|
||||
"info": "The rank of LoRA matrices."
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 秩",
|
||||
"info": "LoRA 矩阵的秩。"
|
||||
}
|
||||
},
|
||||
"lora_dropout": {
|
||||
"en": {
|
||||
"label": "LoRA Dropout",
|
||||
"info": "Dropout ratio of LoRA weights."
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 随机丢弃",
|
||||
"info": "LoRA 权重随机丢弃的概率。"
|
||||
}
|
||||
},
|
||||
"lora_target": {
|
||||
"en": {
|
||||
"label": "LoRA modules (optional)",
|
||||
"info": "The name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 作用层(非必填)",
|
||||
"info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
|
||||
}
|
||||
},
|
||||
"start_btn": {
|
||||
"en": {
|
||||
"value": "Start"
|
||||
@@ -323,6 +389,14 @@ LOCALES = {
|
||||
"value": "模型未加载,请先加载模型。"
|
||||
}
|
||||
},
|
||||
"prefix": {
|
||||
"en": {
|
||||
"placeholder": "System prompt (optional)"
|
||||
},
|
||||
"zh": {
|
||||
"placeholder": "系统提示词(非必填)"
|
||||
}
|
||||
},
|
||||
"query": {
|
||||
"en": {
|
||||
"placeholder": "Input..."
|
||||
@@ -378,6 +452,34 @@ LOCALES = {
|
||||
"zh": {
|
||||
"label": "温度系数"
|
||||
}
|
||||
},
|
||||
"save_dir": {
|
||||
"en": {
|
||||
"label": "Export dir",
|
||||
"info": "Directory to save exported model."
|
||||
},
|
||||
"zh": {
|
||||
"label": "导出目录",
|
||||
"info": "保存导出模型的文件夹路径。"
|
||||
}
|
||||
},
|
||||
"max_shard_size": {
|
||||
"en": {
|
||||
"label": "Max shard size (GB)",
|
||||
"info": "The maximum size for a model file."
|
||||
},
|
||||
"zh": {
|
||||
"label": "最大分块大小(GB)",
|
||||
"info": "模型文件的最大大小。"
|
||||
}
|
||||
},
|
||||
"export_btn": {
|
||||
"en": {
|
||||
"value": "Export"
|
||||
},
|
||||
"zh": {
|
||||
"value": "开始导出"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -403,6 +505,14 @@ ALERTS = {
|
||||
"en": "Please choose a dataset.",
|
||||
"zh": "请选择数据集。"
|
||||
},
|
||||
"err_no_checkpoint": {
|
||||
"en": "Please select a checkpoint.",
|
||||
"zh": "请选择断点。"
|
||||
},
|
||||
"err_no_save_dir": {
|
||||
"en": "Please provide export dir.",
|
||||
"zh": "请填写导出目录"
|
||||
},
|
||||
"info_aborting": {
|
||||
"en": "Aborted, wait for terminating...",
|
||||
"zh": "训练中断,正在等待线程结束……"
|
||||
@@ -430,5 +540,13 @@ ALERTS = {
|
||||
"info_unloaded": {
|
||||
"en": "Model unloaded.",
|
||||
"zh": "模型已卸载。"
|
||||
},
|
||||
"info_exporting": {
|
||||
"en": "Exporting model...",
|
||||
"zh": "正在导出模型……"
|
||||
},
|
||||
"info_exported": {
|
||||
"en": "Model exported.",
|
||||
"zh": "模型导出完成。"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import gradio as gr
|
||||
from typing import Any, Dict, List
|
||||
from gradio.components import Component
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from llmtuner.webui.common import get_model_path, list_dataset, load_config
|
||||
from llmtuner.webui.locales import LOCALES
|
||||
@@ -24,7 +24,7 @@ class Manager:
|
||||
|
||||
return refresh_dict
|
||||
|
||||
def gen_label(self, lang: str) -> Dict[Component, dict]:
|
||||
def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
|
||||
update_dict = {}
|
||||
refresh_dict = self.gen_refresh()
|
||||
|
||||
|
||||
@@ -3,10 +3,10 @@ import os
|
||||
import threading
|
||||
import time
|
||||
import transformers
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE # will be deprecated
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE
|
||||
from llmtuner.extras.logging import LoggerHandler
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.tuner import get_train_args, run_sft
|
||||
@@ -25,7 +25,9 @@ class Runner:
|
||||
self.aborted = True
|
||||
self.running = False
|
||||
|
||||
def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]:
|
||||
def initialize(
|
||||
self, lang: str, model_name: str, dataset: List[str]
|
||||
) -> Tuple[str, str, LoggerHandler, LogCallback]:
|
||||
if self.running:
|
||||
return None, ALERTS["err_conflict"][lang], None, None
|
||||
|
||||
@@ -50,7 +52,9 @@ class Runner:
|
||||
|
||||
return model_name_or_path, "", logger_handler, trainer_callback
|
||||
|
||||
def finalize(self, lang: str, finish_info: Optional[str] = None) -> str:
|
||||
def finalize(
|
||||
self, lang: str, finish_info: Optional[str] = None
|
||||
) -> str:
|
||||
self.running = False
|
||||
torch_gc()
|
||||
if self.aborted:
|
||||
@@ -77,12 +81,17 @@ class Runner:
|
||||
batch_size: int,
|
||||
gradient_accumulation_steps: int,
|
||||
lr_scheduler_type: str,
|
||||
max_grad_norm: str,
|
||||
dev_ratio: float,
|
||||
fp16: bool,
|
||||
logging_steps: int,
|
||||
save_steps: int,
|
||||
warmup_steps: int,
|
||||
compute_type: str,
|
||||
lora_rank: int,
|
||||
lora_dropout: float,
|
||||
lora_target: str,
|
||||
output_dir: str
|
||||
):
|
||||
) -> Generator[str, None, None]:
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error
|
||||
@@ -99,11 +108,10 @@ class Runner:
|
||||
model_name_or_path=model_name_or_path,
|
||||
do_train=True,
|
||||
overwrite_cache=True,
|
||||
lora_target=DEFAULT_MODULE.get(model_name.split("-")[0], None) or "q_proj,v_proj",
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||
prompt_template=template,
|
||||
template=template,
|
||||
source_prefix=source_prefix,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
@@ -115,9 +123,15 @@ class Runner:
|
||||
per_device_train_batch_size=batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
fp16=fp16,
|
||||
max_grad_norm=float(max_grad_norm),
|
||||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
warmup_steps=warmup_steps,
|
||||
fp16=(compute_type == "fp16"),
|
||||
bf16=(compute_type == "bf16"),
|
||||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
||||
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
|
||||
)
|
||||
|
||||
@@ -164,7 +178,7 @@ class Runner:
|
||||
max_samples: str,
|
||||
batch_size: int,
|
||||
predict: bool
|
||||
):
|
||||
) -> Generator[str, None, None]:
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error
|
||||
@@ -187,7 +201,7 @@ class Runner:
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||
prompt_template=template,
|
||||
template=template,
|
||||
source_prefix=source_prefix,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
|
||||
@@ -3,11 +3,13 @@ import json
|
||||
import gradio as gr
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import Any, Dict, Tuple
|
||||
from typing import Any, Dict, Generator, List, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from llmtuner.extras.ploting import smooth
|
||||
from llmtuner.webui.common import get_save_dir, DATA_CONFIG
|
||||
from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
|
||||
|
||||
def format_info(log: str, tracker: dict) -> str:
|
||||
@@ -83,3 +85,41 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
|
||||
ax.set_xlabel("step")
|
||||
ax.set_ylabel("loss")
|
||||
return fig
|
||||
|
||||
|
||||
def export_model(
|
||||
lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
|
||||
) -> Generator[str, None, None]:
|
||||
if not model_name:
|
||||
yield ALERTS["err_no_model"][lang]
|
||||
return
|
||||
|
||||
model_name_or_path = get_model_path(model_name)
|
||||
if not model_name_or_path:
|
||||
yield ALERTS["err_no_path"][lang]
|
||||
return
|
||||
|
||||
if not checkpoints:
|
||||
yield ALERTS["err_no_checkpoint"][lang]
|
||||
return
|
||||
|
||||
checkpoint_dir = ",".join(
|
||||
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
|
||||
)
|
||||
|
||||
if not save_dir:
|
||||
yield ALERTS["err_no_save_dir"][lang]
|
||||
return
|
||||
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type
|
||||
)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
model_args, _, finetuning_args, _ = get_infer_args(args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
|
||||
tokenizer.save_pretrained(save_dir)
|
||||
yield ALERTS["info_exported"][lang]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from llmtuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
|
||||
from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from llmtuner import create_ui
|
||||
from llmtuner.webui.interface import create_ui
|
||||
|
||||
|
||||
def main():
|
||||
demo = create_ui()
|
||||
demo.queue()
|
||||
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import gradio as gr
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from llmtuner import get_infer_args
|
||||
from llmtuner.tuner import get_infer_args
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
from llmtuner.webui.components.chatbot import create_chat_box
|
||||
from llmtuner.webui.manager import Manager
|
||||
@@ -24,20 +24,12 @@ def main():
|
||||
|
||||
manager = Manager([{"lang": lang}, chat_elems])
|
||||
|
||||
demo.load(
|
||||
manager.gen_label,
|
||||
[lang],
|
||||
[lang] + [elem for elem in chat_elems.values()],
|
||||
)
|
||||
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
|
||||
|
||||
lang.change(
|
||||
manager.gen_label,
|
||||
[lang],
|
||||
[lang] + [elem for elem in chat_elems.values()],
|
||||
)
|
||||
lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
|
||||
|
||||
demo.queue()
|
||||
demo.launch(server_name="0.0.0.0", share=False, inbrowser=True)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user