update ppo trainer

Former-commit-id: caa525a5c6f228b9ad71387d1fe4f1c2ffa2479e
This commit is contained in:
hiyouga
2023-11-20 21:39:15 +08:00
parent e585950c54
commit 28258aecd2
7 changed files with 68 additions and 41 deletions

View File

@@ -319,9 +319,13 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--save_steps 1000 \
--learning_rate 1e-5 \
--num_train_epochs 1.0 \
--plot_loss
--plot_loss \
--fp16
```
> [!WARNING]
> 如果在 fp16 精度下训练 LLaMA-2 模型,请使用 `--per_device_eval_batch_size=1`。
#### DPO 训练
```bash
@@ -417,7 +421,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
</details>
### 导出微调后的完整模型
### 合并 LoRA 权重并导出完整模型
```bash
python src/export_model.py \
@@ -438,7 +442,7 @@ python src/api_demo.py \
--checkpoint_dir path_to_checkpoint
```
> [!NOTE]
> [!TIP]
> 关于 API 文档请见 `http://localhost:8000/docs`。
### 命令行测试
@@ -490,10 +494,14 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
--predict_with_generate \
--fp16
```
> [!NOTE]
> [!WARNING]
> 如果在 fp16 精度下推理 LLaMA-2 模型,请使用 `--per_device_eval_batch_size=1`。
> [!TIP]
> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
## 使用了 LLaMA Factory 的项目
@@ -503,7 +511,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
> [!NOTE]
> [!TIP]
> 如果您有项目希望添加至上述列表,请通过邮件联系或者创建一个 PR。
## 协议