support DPO training (2305.18290)

Former-commit-id: 6d98de148e4af63a7028dfaeb6cf86eb56a4488f
This commit is contained in:
hiyouga
2023-08-11 03:02:53 +08:00
parent 72dfd74005
commit ca719a8697
33 changed files with 513 additions and 192 deletions

View File

@@ -12,7 +12,9 @@
## 更新日志
[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat``--lora_target c_attn` 参数。请注意使用 Qwen-7B-Chat 模型需要添加 `--template chatml` 参数
[23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-training)(实验性功能)
[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat``--lora_target c_attn` 参数。使用 Qwen-7B-Chat 模型请添加 `--template chatml` 参数。
[23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming``--max_steps 100` 参数来流式加载数据集。
@@ -54,41 +56,34 @@
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
> * **默认模块**是 `--lora_target` 参数的默认值。请使用 `python src/train_bash.py -h` 查看全部可选项。
> * 对于所有“基座”模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等
- **默认模块**是 `--lora_target` 参数的部分可选项。请使用 `python src/train_bash.py -h` 查看全部可选项。
- 对于所有“基座”Base模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna`任意值。但“对话”Chat模型请务必使用对应的模板
## 微调方法
## 训练方法
- [二次预训练](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
- 全参数微调
- 部分参数微调
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [指令监督微调](https://arxiv.org/abs/2109.01652)
- 全参数微调
- 部分参数微调
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
- [人类反馈的强化学习RLHF](https://arxiv.org/abs/2203.02155)
- [LoRA](https://arxiv.org/abs/2106.09685)
- [QLoRA](https://arxiv.org/abs/2305.14314)
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
| ---------- | ---------- | ----------- | ---- | ----- |
| 预训练 | ✅ | ✅ | ✅ | ✅ |
| 指令监督微调 | ✅ | ✅ | ✅ | ✅ |
| 奖励模型训练 | | | ✅ | ✅ |
| PPO 训练 | | | ✅ | ✅ |
| DPO 训练 | ✅ | | ✅ | ✅ |
## 数据集
- 用于二次预训练:
- 用于预训练
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- 用于指令监督微调:
- 用于指令监督微调
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -103,7 +98,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
- 用于奖励模型训练:
- 用于奖励模型或 DPO 训练
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -139,7 +134,6 @@ huggingface-cli login
### 环境搭建(可跳过)
```bash
git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning
@@ -161,7 +155,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
目前网页 UI 仅支持**单卡训练**。
### 二次预训练
### 预训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -222,7 +216,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 4 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
@@ -233,7 +227,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--fp16
```
### RLHF 训练
### PPO 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
@@ -257,6 +251,30 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--plot_loss
```
### DPO 训练
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage dpo \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_zh \
--template default \
--finetuning_type lora \
--resume_lora_training False \
--checkpoint_dir path_to_sft_checkpoint \
--output_dir path_to_dpo_checkpoint \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss \
--fp16
```
### 多 GPU 分布式训练
```bash