support FlashAttention2
Former-commit-id: 23e56c5554b948d4f08ad87849b261eafd2c7890

README.md
@@ -12,11 +12,13 @@
 ## Changelog
 
+[23/09/10] Now we support using **[FlashAttention](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try the `--flash_attn` argument to enable FlashAttention-2 if you are using RTX 4090, A100 or H100 GPUs (experimental feature).
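
For illustration only, a fine-tuning command with FlashAttention-2 enabled might look like the sketch below. Only `--flash_attn` comes from this changelog entry; the model path, dataset, and the other arguments are assumed placeholders, not a prescribed recipe.

```bash
# Hypothetical LoRA fine-tuning run with FlashAttention-2 enabled.
# Every argument except --flash_attn is an illustrative placeholder.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset alpaca_gpt4_en \
    --finetuning_type lora \
    --output_dir path_to_sft_checkpoint \
    --flash_attn \
    --fp16
```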
+
 [23/08/18] Now we support **resuming training**; upgrade `transformers` to `4.31.0` to enjoy this feature.
 
 [23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try the `--rope_scaling linear` argument in training and the `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
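
A minimal sketch of the two modes described above; the remaining arguments are elided, and `src/cli_demo.py` is an assumed entry point for inference (see the CLI Demo section).

```bash
# Sketch: linear RoPE scaling while training.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --rope_scaling linear \
    ... # remaining training arguments

# Sketch: dynamic RoPE scaling at inference time to extrapolate
# the position embeddings beyond the trained length.
CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \
    --rope_scaling dynamic \
    ... # remaining inference arguments
```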
 
-[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
+[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models.
 
 [23/08/03] Now we support training the **Qwen-7B** model in this repo. Try the `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use the `--template chatml` argument when you are using the Qwen-7B-Chat model.
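
Combining the three arguments named in this entry, a Qwen-7B-Chat run might look like the following sketch; everything else is elided.

```bash
# Sketch: the Qwen-specific arguments from the entry above, combined.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --model_name_or_path Qwen/Qwen-7B-Chat \
    --lora_target c_attn \
    --template chatml \
    ... # arguments (same as above)
```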
@@ -62,8 +64,11 @@
 | [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | xverse |
 | [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
 
-- **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
-- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
+> **Note**
+>
+> The **default module** is used for the `--lora_target` argument. You can use `--lora_target all` to specify all the available modules.
+>
+> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna`, etc. Make sure to use the corresponding template for the "chat" models.
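
As a sketch of this note, overriding the default module and matching the template to a "chat" model could look like the following; the ChatGLM2 identifiers are taken from the table above purely as an example, and the Hugging Face model path is an assumption.

```bash
# Sketch: target all available LoRA modules instead of the per-model
# default, and pick the template that matches the chosen chat model.
# The model path is assumed from the table's ChatGLM2 row.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --model_name_or_path THUDM/chatglm2-6b \
    --lora_target all \
    --template chatglm2 \
    ... # arguments (same as above)
```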
 
 ## Supported Training Approaches
 
@@ -75,7 +80,9 @@
 | PPO Training | | | :white_check_mark: | :white_check_mark: |
 | DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
 
-- Use `--quantization_bit 4/8` argument to enable QLoRA.
+> **Note**
+>
+> Use the `--quantization_bit 4/8` argument to enable QLoRA.
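
A sketch of what this note enables; `--finetuning_type lora` is an assumption here (QLoRA quantizes the base model underneath LoRA adapters), not something the note states.

```bash
# Sketch: 4-bit QLoRA; pass 8 instead of 4 for 8-bit quantization.
# --finetuning_type lora is assumed, not stated by the note above.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --finetuning_type lora \
    --quantization_bit 4 \
    ... # arguments (same as above)
```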
 
 ## Provided Datasets
 
@@ -138,7 +145,9 @@ And **powerful GPUs**!
 Please refer to `data/example_dataset` for details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
 
-Note: please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
+> **Note**
+>
+> Please update `data/dataset_info.json` to use your custom dataset. For the format of this file, please refer to `data/README.md`.
 
 ### Dependency Installation (optional)
 
@@ -164,10 +173,16 @@ CUDA_VISIBLE_DEVICES=0 python src/train_web.py
 We strongly recommend using the all-in-one Web UI for newcomers, since it can also generate training scripts **automatically**.
 
-Currently the web UI only supports training on **a single GPU**.
+> **Warning**
+>
+> Currently the web UI only supports training on **a single GPU**.
 
 ### Train on a single GPU
 
+> **Warning**
+>
+> If you want to train models on multiple GPUs, please refer to [Distributed Training](#distributed-training).
 
 #### Pre-Training
 
 ```bash
@@ -300,19 +315,13 @@ accelerate config # configure the environment
 accelerate launch src/train_bash.py # arguments (same as above)
 ```
 
-<details><summary>Example config.yaml for training with DeepSpeed ZeRO-2</summary>
+<details><summary>Example config for LoRA training</summary>
 
 ```yaml
 compute_environment: LOCAL_MACHINE
-deepspeed_config:
-  gradient_accumulation_steps: 4
-  gradient_clipping: 0.5
-  offload_optimizer_device: none
-  offload_param_device: none
-  zero3_init_flag: false
-  zero_stage: 2
-distributed_type: DEEPSPEED
+distributed_type: MULTI_GPU
 downcast_bf16: 'no'
 gpu_ids: all
 machine_rank: 0
 main_training_function: main
 mixed_precision: fp16
@@ -336,7 +345,7 @@ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
     ... # arguments (same as above)
 ```
 
-<details><summary>Example ds_config.json for training with DeepSpeed ZeRO-2</summary>
+<details><summary>Example config for full-parameter training with DeepSpeed ZeRO-2</summary>
 
 ```json
 {
@@ -387,7 +396,9 @@ python src/api_demo.py \
     --checkpoint_dir path_to_checkpoint
 ```
 
-Visit `http://localhost:8000/docs` for API documentation.
+> **Note**
+>
+> Visit `http://localhost:8000/docs` for API documentation.
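
The note only promises interactive documentation at `/docs`; as a purely hypothetical smoke test, a request might look like the sketch below, with the real route and payload schema taken from that page.

```bash
# Hypothetical request; the route and payload are assumptions.
# Confirm the actual schema at http://localhost:8000/docs first.
curl -s http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"messages": [{"role": "user", "content": "Hello!"}]}'
```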
 
 ### CLI Demo
 
@@ -426,7 +437,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --predict_with_generate
 ```
 
-We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
+> **Note**
+>
+> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` for 4/8-bit evaluation.
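
Applied to the evaluation command above, the recommendation might look like this sketch; every flag shown appears elsewhere in this README, and the rest are elided.

```bash
# Sketch: 4-bit evaluation with the recommended settings.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --quantization_bit 4 \
    --per_device_eval_batch_size=1 \
    --max_target_length 128 \
    --predict_with_generate \
    ... # arguments (same as above)
```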
 
 ### Predict
 
@@ -445,12 +458,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --predict_with_generate
 ```
 
-## TODO
-
-- [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
-- [ ] Implementing multi-query attention for faster inference.
-- [ ] Supporting full-parameter RLHF training.
-
 ## License
 
 This repository is licensed under the [Apache-2.0 License](LICENSE).