support FlashAttention2

Former-commit-id: 23e56c5554b948d4f08ad87849b261eafd2c7890
hiyouga
2023-09-10 20:43:56 +08:00
parent b481ad58e6
commit a402161631
9 changed files with 875 additions and 115 deletions


@@ -12,11 +12,13 @@
## Changelog
+[23/09/10] Now we support using **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try the `--flash_attn` argument to enable FlashAttention-2 if you are using RTX 4090, A100 or H100 GPUs (experimental feature).
[23/08/18] Now we support **resuming training**; upgrade `transformers` to `4.31.0` to enjoy this feature.
[23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try the `--rope_scaling linear` argument during training and the `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
-[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
+[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models.
[23/08/03] Now we support training the **Qwen-7B** model in this repo. Try the `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use the `--template chatml` argument when you are using the Qwen-7B-Chat model.
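For illustration (this sketch is not taken from the commit itself), the two new flags above can be combined in a single fine-tuning command; the model path, dataset, template and output directory are placeholders following the naming used elsewhere in this README:

```bash
# Hypothetical LoRA fine-tuning run that enables FlashAttention-2 and linear
# RoPE scaling; every path and dataset name below is a placeholder.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --flash_attn \
    --rope_scaling linear \
    --output_dir path_to_sft_checkpoint \
    --fp16
```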
@@ -62,8 +64,11 @@
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
-- **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
-- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
+> **Note**
+>
+> **Default module** is used for the `--lora_target` argument; you can use `--lora_target all` to specify all the available modules.
+>
+> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc., but make sure to use the corresponding template for the "chat" models.
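As a concrete sketch of both points, a command for the ChatGLM2-6B row of the table would pass its default module to `--lora_target` and its chat template to `--template`; the remaining arguments are placeholders rather than values from this commit:

```bash
# Hypothetical run for ChatGLM2-6B: query_key_value is its default LoRA module
# and chatglm2 is its template (see the table above); other values are placeholders.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path THUDM/chatglm2-6b \
    --dataset alpaca_gpt4_en \
    --template chatglm2 \
    --finetuning_type lora \
    --lora_target query_key_value \
    --output_dir path_to_sft_checkpoint \
    --fp16
```

Passing `--lora_target all` instead would attach LoRA adapters to every available module of the model.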
## Supported Training Approaches
@@ -75,7 +80,9 @@
| PPO Training | | | :white_check_mark: | :white_check_mark: |
| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
-- Use `--quantization_bit 4/8` argument to enable QLoRA.
+> **Note**
+>
+> Use the `--quantization_bit 4/8` argument to enable QLoRA.
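For example, 4-bit QLoRA amounts to adding that single flag to an ordinary LoRA command; this is a sketch with placeholder paths and dataset, mirroring the training examples later in this README:

```bash
# Hypothetical 4-bit QLoRA run: only --quantization_bit differs from a plain
# LoRA command; paths and dataset are placeholders.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path path_to_llama_model \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --quantization_bit 4 \
    --output_dir path_to_sft_checkpoint \
    --fp16
```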
## Provided Datasets
@@ -138,7 +145,9 @@ And **powerful GPUs**!
Please refer to `data/example_dataset` for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
-Note: please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
+> **Note**
+>
+> Please update `data/dataset_info.json` to use your custom dataset. For the format of this file, please refer to `data/README.md`.
### Dependence Installation (optional)
@@ -164,10 +173,16 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
We strongly recommend using the all-in-one Web UI for newcomers since it can also generate training scripts **automatically**.
-Currently the web UI only supports training on **a single GPU**.
+> **Warning**
+>
+> Currently the web UI only supports training on **a single GPU**.
### Train on a single GPU
+> **Warning**
+>
+> If you want to train models on multiple GPUs, please refer to [Distributed Training](#distributed-training).
#### Pre-Training
```bash
@@ -300,19 +315,13 @@ accelerate config # configure the environment
accelerate launch src/train_bash.py # arguments (same as above)
```
-<details><summary>Example config.yaml for training with DeepSpeed ZeRO-2</summary>
+<details><summary>Example config for LoRA training</summary>
```yaml
compute_environment: LOCAL_MACHINE
-deepspeed_config:
-  gradient_accumulation_steps: 4
-  gradient_clipping: 0.5
-  offload_optimizer_device: none
-  offload_param_device: none
-  zero3_init_flag: false
-  zero_stage: 2
-distributed_type: DEEPSPEED
+distributed_type: MULTI_GPU
downcast_bf16: 'no'
+gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
@@ -336,7 +345,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
... # arguments (same as above)
```
-<details><summary>Example ds_config.json for training with DeepSpeed ZeRO-2</summary>
+<details><summary>Example config for full-parameter training with DeepSpeed ZeRO-2</summary>
```json
{
@@ -387,7 +396,9 @@ python src/api_demo.py \
--checkpoint_dir path_to_checkpoint
```
-Visit `http://localhost:8000/docs` for API documentation.
+> **Note**
+>
+> Visit `http://localhost:8000/docs` for API documentation.
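For a quick smoke test once the server is up, something like the request below could be used. The `/v1/chat/completions` route and the JSON body are assumptions based on the API being OpenAI-style; confirm the actual routes and schema on the `/docs` page:

```bash
# Hypothetical request to the local API demo; the OpenAI-style route and JSON
# body are assumptions, so check http://localhost:8000/docs for the real schema.
curl -s http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "default",
        "messages": [{"role": "user", "content": "Hello!"}]
      }'
```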
### CLI Demo
@@ -426,7 +437,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--predict_with_generate
```
-We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
+> **Note**
+>
+> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` for 4/8-bit evaluation.
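Applied to the evaluation command above, that recommendation might look like the following sketch; apart from `--quantization_bit`, `--per_device_eval_batch_size` and `--max_target_length`, every value is a placeholder mirroring the earlier evaluation example:

```bash
# Hypothetical 8-bit evaluation run combining --quantization_bit 8 with the
# two recommended flags; paths and dataset names are placeholders.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_eval \
    --model_name_or_path path_to_llama_model \
    --checkpoint_dir path_to_checkpoint \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --quantization_bit 8 \
    --per_device_eval_batch_size 1 \
    --max_target_length 128 \
    --output_dir path_to_eval_result \
    --predict_with_generate
```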
### Predict
@@ -445,12 +458,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--predict_with_generate
```
-## TODO
-- [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
-- [ ] Implementing multi-query attention for faster inference.
-- [ ] Supporting full-parameter RLHF training.
## License
This repository is licensed under the [Apache-2.0 License](LICENSE).