Compare commits

37 Commits

| Author | SHA1 | Date |
| --- | --- | --- |
|  | 48d2e6d7fe |  |
|  | 041c83ea03 |  |
|  | 0e621c2dc9 |  |
|  | 544e7a491b |  |
|  | a2c881fa08 |  |
|  | c53c7af168 |  |
|  | a2d93e5269 |  |
|  | b392e6cfb9 |  |
|  | 13aa2d389a |  |
|  | 1e7962dfc4 |  |
|  | 1c9556c84c |  |
|  | ca3ca7a5b5 |  |
|  | 0500befdb4 |  |
|  | f618feab51 |  |
|  | 4b06aa134f |  |
|  | 9cde56d760 |  |
|  | d0ea203694 |  |
|  | c5eb3fba62 |  |
|  | a8bc32553c |  |
|  | 88f3358320 |  |
|  | a85bdcf2f6 |  |
|  | caf56b313e |  |
|  | 75603c45fc |  |
|  | 89f86cc970 |  |
|  | c09a0e4f08 |  |
|  | 7bac6c9460 |  |
|  | 0c7d0bf172 |  |
|  | a274900188 |  |
|  | 67deefe527 |  |
|  | 823f618cba |  |
|  | bc16c9a54a |  |
|  | a3f30038a0 |  |
|  | e237f618c2 |  |
|  | 688adad665 |  |
|  | 0158812afb |  |
|  | e52e0d9b07 |  |
|  | eb2aa2c073 |  |
.github/PULL_REQUEST_TEMPLATE.md (vendored, new file, +7)

@@ -0,0 +1,7 @@
+# What does this PR do?
+
+Fixes # (issue)
+
+## Before submitting
+
+- [ ] Did you read the [contributor guideline](/CONTRIBUTING.md)?
.github/workflows/tests.yml (vendored, 2 changes)

@@ -22,7 +22,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install black ruff
+          python -m pip install ruff

       - name: Check quality
         run: |
CONTRIBUTING.md (new file, +21)

@@ -0,0 +1,21 @@
+# Contributing to LLaMA Factory
+
+Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
+
+It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
+
+However you choose to contribute, please be mindful and respect our [code of conduct](CODE_OF_CONDUCT.md).
+
+**This guide was heavily inspired by [transformers guide to contributing](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md).**
+
+## Ways to contribute
+
+There are several ways you can contribute to LLaMA Factory:
+
+* Fix outstanding issues with the existing code.
+* Submit issues related to bugs or desired new features.
+* Contribute to the examples or to the documentation.
+
+### Style guide
+
+LLaMA Factory follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html), check it for details.
Makefile (4 changes)

@@ -3,9 +3,9 @@
 check_dirs := src tests

 quality:
-	black --check $(check_dirs)
 	ruff $(check_dirs)
+	ruff format --check $(check_dirs)

 style:
-	black $(check_dirs)
 	ruff $(check_dirs) --fix
+	ruff format $(check_dirs)
README.md (97 changes)

@@ -5,6 +5,7 @@
 [](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [](https://pypi.org/project/llmtuner/)
 [](https://pypi.org/project/llmtuner/)
+[](#projects-using-llama-factory)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [](https://discord.gg/rKfvV9r9FK)
 [](https://huggingface.co/spaces/hiyouga/LLaMA-Board)

@@ -16,9 +17,7 @@
 ## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory

-Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** or **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)**.
-
-Launch LLaMA Board via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet in this mode)
+Preview LLaMA Board at **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** and **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)**, or launch it locally with `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`.

 Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.

@@ -26,6 +25,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 ## Table of Contents

 - [Features](#features)
+- [Benchmark](#benchmark)
 - [Changelog](#changelog)
 - [Supported Models](#supported-models)

@@ -38,6 +38,15 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 - [Citation](#citation)
 - [Acknowledgement](#acknowledgement)

 ## Features

 - **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
 - **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO and DPO.
 - **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA, 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
 - **Advanced algorithms**: DoRA, LongLoRA, LLaMA Pro, LoftQ, agent tuning.
 - **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune, rsLoRA.
 - **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.

+## Benchmark
+
+Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning), LLaMA-Factory's LoRA tuning offers up to **3.7 times faster** training speed with a better Rouge score on the advertising text generation task. By leveraging 4-bit quantization technique, LLaMA-Factory's QLoRA further improves the efficiency regarding the GPU memory.

@@ -55,14 +64,16 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 ## Changelog

+[24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
+
+[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `tests/llama_pro.py` for usage.
+
 [24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.

-[24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`.
-
 <details><summary>Full Changelog</summary>

 [24/01/18] We supported **agent tuning** for most models, equipping model with tool using abilities by fine-tuning with `--dataset glaive_toolcall`.

 [23/12/23] We supported **[unsloth](https://github.com/unslothai/unsloth)**'s implementation to boost LoRA tuning for the LLaMA, Mistral and Yi models. Try `--use_unsloth` argument to activate unsloth patch. It achieves 1.7x speed in our benchmark, check [this page](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison) for details.

 [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).

@@ -107,6 +118,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
 | [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
+| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
 | [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |

@@ -128,7 +140,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 ## Supported Training Approaches

-| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
+| Approach | Full-tuning | Freeze-tuning | LoRA | QLoRA |
 | ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
 | Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
 | Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |

@@ -225,15 +237,27 @@ huggingface-cli login
 ## Requirement

-- Python 3.8+ and PyTorch 1.13.1+
-- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
-- sentencepiece, protobuf and tiktoken
-- jieba, rouge-chinese and nltk (used at evaluation and predict)
-- gradio and matplotlib (used in web UI)
-- uvicorn, fastapi and sse-starlette (used in API)
+| Mandatory    | Minimum | Recommend |
+| ------------ | ------- | --------- |
+| python       | 3.8     | 3.10      |
+| torch        | 1.13.1  | 2.2.1     |
+| transformers | 4.37.2  | 4.38.1    |
+| datasets     | 2.14.3  | 2.17.1    |
+| accelerate   | 0.27.2  | 0.27.2    |
+| peft         | 0.9.0   | 0.9.0     |
+| trl          | 0.7.11  | 0.7.11    |
+
+| Optional     | Minimum | Recommend |
+| ------------ | ------- | --------- |
+| CUDA         | 11.6    | 12.2      |
+| deepspeed    | 0.10.0  | 0.13.4    |
+| bitsandbytes | 0.39.0  | 0.41.3    |
+| flash-attn   | 2.3.0   | 2.5.5     |

 ### Hardware Requirement

+\* *estimated*
+
 | Method | Bits | 7B | 13B | 30B | 65B | 8x7B |
 | ------ | ---- | ----- | ----- | ----- | ------ | ------ |
 | Full   | 16   | 160GB | 320GB | 600GB | 1200GB | 900GB  |

@@ -261,12 +285,14 @@ cd LLaMA-Factory
 pip install -r requirements.txt
 ```

-If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.1.
+If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you will be required to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.2.

 ```bash
-pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
+pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.0-py3-none-win_amd64.whl
 ```

+To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
+
 ### Use ModelScope Hub (optional)

 If you have trouble with downloading models and datasets from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.

@@ -394,6 +420,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```

+> [!TIP]
+> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` to infer the fine-tuned model.
+
 > [!WARNING]
 > Use `--per_device_train_batch_size=1` for LLaMA-2 models in fp16 PPO training.

@@ -422,6 +451,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```

+> [!TIP]
+> Use `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` to infer the fine-tuned model.
+
 ### Distributed Training

 #### Use Huggingface Accelerate

@@ -435,6 +467,7 @@ accelerate launch src/train_bash.py # arguments (same as above)
 ```yaml
 compute_environment: LOCAL_MACHINE
+debug: false
 distributed_type: MULTI_GPU
 downcast_bf16: 'no'
 gpu_ids: all

@@ -511,7 +544,7 @@ python src/export_model.py \
 > [!TIP]
 > Use `--export_quantization_bit 4` and `--export_quantization_dataset data/c4_demo.json` to quantize the model after merging the LoRA weights.

-### API Demo
+### Inference with OpenAI-style API

 ```bash
 python src/api_demo.py \

@@ -524,7 +557,7 @@ python src/api_demo.py \
 > [!TIP]
 > Visit `http://localhost:8000/docs` for API documentation.

-### CLI Demo
+### Inference with command line

 ```bash
 python src/cli_demo.py \

@@ -534,7 +567,7 @@ python src/cli_demo.py \
     --finetuning_type lora
 ```

-### Web Demo
+### Inference with web browser

 ```bash
 python src/web_demo.py \

@@ -571,7 +604,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --template default \
     --finetuning_type lora \
     --output_dir path_to_predict_result \
-    --per_device_eval_batch_size 8 \
+    --per_device_eval_batch_size 1 \
     --max_samples 100 \
     --predict_with_generate \
     --fp16

@@ -585,11 +618,27 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ## Projects using LLaMA Factory

-- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
-- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
-- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
-- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
-- **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.
+1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
+1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
+1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
+1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
+1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
+1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
+1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
+1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
+1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
+1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
+1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
+1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
+1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
+1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
+1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
+1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
+1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
+1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
+1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
+1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
+1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**: A series of MBTI Personality large language models, capable of giving any LLM 16 different personality types based on different datasets and training methods.

 > [!TIP]
 > If you have a project that should be incorporated, please contact via email or create a pull request.

@@ -598,7 +647,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 This repository is licensed under the [Apache-2.0 License](LICENSE).

-Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+Please follow the model licenses to use the corresponding model weights: [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

 ## Citation
README_zh.md (95 changes)

@@ -5,6 +5,7 @@
 [](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [](https://pypi.org/project/llmtuner/)
 [](https://pypi.org/project/llmtuner/)
+[](#使用了-llama-factory-的项目)
 [](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [](https://discord.gg/rKfvV9r9FK)
 [](https://huggingface.co/spaces/hiyouga/LLaMA-Board)

@@ -16,9 +17,7 @@
 ## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory

-通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 或 **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)** 预览 LLaMA Board。
-
-使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 LLaMA Board。(该模式目前仅支持单卡训练)
+通过 **[🤗 Spaces](https://huggingface.co/spaces/hiyouga/LLaMA-Board)** 或 **[ModelScope](https://modelscope.cn/studios/hiyouga/LLaMA-Board)** 预览 LLaMA Board,或者通过命令 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 本地启动。

 下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。

@@ -26,6 +25,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 ## 目录

 - [项目特色](#项目特色)
+- [性能指标](#性能指标)
 - [更新日志](#更新日志)
 - [模型](#模型)

@@ -38,6 +38,15 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 - [引用](#引用)
 - [致谢](#致谢)

 ## 项目特色

 - **多种模型**:LLaMA、Mistral、Mixtral-MoE、Qwen、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。
 - **集成方法**:(增量)预训练、指令监督微调、奖励模型训练、PPO 训练和 DPO 训练。
 - **多种精度**:32 比特全参数微调、16 比特冻结微调、16 比特 LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8 的 2/4/8 比特 QLoRA 微调。
 - **先进算法**:DoRA、LongLoRA、LLaMA Pro、LoftQ 和 Agent 微调。
 - **实用技巧**:FlashAttention-2、Unsloth、RoPE scaling、NEFTune 和 rsLoRA。
 - **实验监控**:LlamaBoard、TensorBoard、Wandb、MLflow 等等。

+## 性能指标
+
+与 ChatGLM 官方的 [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ptuning) 微调相比,LLaMA-Factory 的 LoRA 微调提供了 **3.7 倍**的加速比,同时在广告文案生成任务上取得了更高的 Rouge 分数。结合 4 比特量化技术,LLaMA-Factory 的 QLoRA 微调进一步降低了 GPU 显存消耗。

@@ -55,14 +64,16 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 ## 更新日志

+[24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
+
+[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `tests/llama_pro.py`。
+
 [24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。

-[24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。
-
 <details><summary>展开日志</summary>

 [24/01/18] 我们针对绝大多数模型实现了 **Agent 微调**,微调时指定 `--dataset glaive_toolcall` 即可使模型获得工具调用能力。

 [23/12/23] 我们针对 LLaMA, Mistral 和 Yi 模型支持了 **[unsloth](https://github.com/unslothai/unsloth)** 的 LoRA 训练加速。请使用 `--use_unsloth` 参数启用 unsloth 优化。该方法可提供 1.7 倍的训练速度,详情请查阅[此页面](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-comparison)。

 [23/12/12] 我们支持了微调最新的混合专家模型 **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**。硬件需求请查阅[此处](#硬件依赖)。

@@ -107,6 +118,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 | [ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) | 6B | query_key_value | chatglm3 |
 | [DeepSeek (MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B | q_proj,v_proj | deepseek |
 | [Falcon](https://huggingface.co/tiiuae) | 7B/40B/180B | query_key_value | falcon |
+| [Gemma](https://huggingface.co/google) | 2B/7B | q_proj,v_proj | gemma |
 | [InternLM2](https://huggingface.co/internlm) | 7B/20B | wqkv | intern2 |
 | [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |

@@ -225,15 +237,27 @@ huggingface-cli login
 ## 软硬件依赖

-- Python 3.8+ 和 PyTorch 1.13.1+
-- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
-- sentencepiece, protobuf 和 tiktoken
-- jieba, rouge-chinese 和 nltk (用于评估及预测)
-- gradio 和 matplotlib (用于网页端交互)
-- uvicorn, fastapi 和 sse-starlette (用于 API)
+| 必需项       | 至少    | 推荐      |
+| ------------ | ------- | --------- |
+| python       | 3.8     | 3.10      |
+| torch        | 1.13.1  | 2.2.1     |
+| transformers | 4.37.2  | 4.38.1    |
+| datasets     | 2.14.3  | 2.17.1    |
+| accelerate   | 0.27.2  | 0.27.2    |
+| peft         | 0.9.0   | 0.9.0     |
+| trl          | 0.7.11  | 0.7.11    |
+
+| 可选项       | 至少    | 推荐      |
+| ------------ | ------- | --------- |
+| CUDA         | 11.6    | 12.2      |
+| deepspeed    | 0.10.0  | 0.13.4    |
+| bitsandbytes | 0.39.0  | 0.41.3    |
+| flash-attn   | 2.3.0   | 2.5.5     |

 ### 硬件依赖

+\* *估算值*
+
 | 训练方法 | 精度 | 7B | 13B | 30B | 65B | 8x7B |
 | ------- | ---- | ----- | ----- | ----- | ------ | ------ |
 | 全参数   | 16   | 160GB | 320GB | 600GB | 1200GB | 900GB  |

@@ -261,12 +285,14 @@ cd LLaMA-Factory
 pip install -r requirements.txt
 ```

-如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.1.
+如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.2。

 ```bash
-pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
+pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.0-py3-none-win_amd64.whl
 ```

+如果要在 Windows 平台上开启 FlashAttention-2,需要安装预编译的 `flash-attn` 库,支持 CUDA 12.1 到 12.2,请根据需求到 [flash-attention](https://github.com/bdashore3/flash-attention/releases) 下载对应版本安装。
+
 ### 使用魔搭社区(可跳过)

 如果您在 Hugging Face 模型和数据集的下载中遇到了问题,可以通过下述方法使用魔搭社区。

@@ -394,6 +420,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```

+> [!TIP]
+> 使用 `--adapter_name_or_path path_to_sft_checkpoint,path_to_ppo_checkpoint` 来进行微调模型的推理。
+
 > [!WARNING]
 > 如果使用 fp16 精度进行 LLaMA-2 模型的 PPO 训练,请使用 `--per_device_train_batch_size=1`。

@@ -422,6 +451,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```

+> [!TIP]
+> 使用 `--adapter_name_or_path path_to_sft_checkpoint,path_to_dpo_checkpoint` 来进行微调模型的推理。
+
 ### 多 GPU 分布式训练

 #### 使用 Huggingface Accelerate

@@ -435,6 +467,7 @@ accelerate launch src/train_bash.py # 参数同上
 ```yaml
 compute_environment: LOCAL_MACHINE
+debug: false
 distributed_type: MULTI_GPU
 downcast_bf16: 'no'
 gpu_ids: all

@@ -511,7 +544,7 @@ python src/export_model.py \
 > [!TIP]
 > 合并 LoRA 权重之后可再次使用 `--export_quantization_bit 4` 和 `--export_quantization_dataset data/c4_demo.json` 量化模型。

-### API 服务
+### 使用 OpenAI 风格 API 推理

 ```bash
 python src/api_demo.py \

@@ -524,7 +557,7 @@ python src/api_demo.py \
 > [!TIP]
 > 关于 API 文档请见 `http://localhost:8000/docs`。

-### 命令行测试
+### 使用命令行推理

 ```bash
 python src/cli_demo.py \

@@ -534,7 +567,7 @@ python src/cli_demo.py \
     --finetuning_type lora
 ```

-### 浏览器测试
+### 使用浏览器推理

 ```bash
 python src/web_demo.py \

@@ -571,7 +604,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --template default \
     --finetuning_type lora \
     --output_dir path_to_predict_result \
-    --per_device_eval_batch_size 8 \
+    --per_device_eval_batch_size 1 \
     --max_samples 100 \
     --predict_with_generate \
     --fp16

@@ -585,11 +618,27 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 ## 使用了 LLaMA Factory 的项目

-- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
-- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
-- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
-- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
-- **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。
+1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
+1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
+1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
+1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
+1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
+1. Wang et al. CANDLE: Iterative Conceptualization and Instantiation Distillation from Large Language Models for Commonsense Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2401.07286)
+1. Choi et al. FACT-GPT: Fact-Checking Augmentation via Claim Matching with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2402.05904)
+1. Zhang et al. AutoMathText: Autonomous Data Selection with Language Models for Mathematical Texts. 2024. [[arxiv]](https://arxiv.org/abs/2402.07625)
+1. Lyu et al. KnowTuning: Knowledge-aware Fine-tuning for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11176)
+1. Yang et al. LaCo: Large Language Model Pruning via Layer Collaps. 2024. [[arxiv]](https://arxiv.org/abs/2402.11187)
+1. Bhardwaj et al. Language Models are Homer Simpson! Safety Re-Alignment of Fine-tuned Language Models through Task Arithmetic. 2024. [[arxiv]](https://arxiv.org/abs/2402.11746)
+1. Yang et al. Enhancing Empathetic Response Generation by Augmenting LLMs with Small-scale Empathetic Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11801)
+1. Yi et al. Generation Meets Verification: Accelerating Large Language Model Inference with Smart Parallel Auto-Correct Decoding. 2024. [[arxiv]](https://arxiv.org/abs/2402.11809)
+1. Cao et al. Head-wise Shareable Attention for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.11819)
+1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
+1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
+1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
+1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
+1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
+1. **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
+1. **[MachineMindset](https://github.com/PKU-YuanGroup/Machine-Mindset/)**:MBTI性格大模型项目,根据数据集与训练方式让任意 LLM 拥有 16 个不同的性格类型。

 > [!TIP]
 > 如果您有项目希望添加至上述列表,请通过邮件联系或者创建一个 PR。

@@ -598,7 +647,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。

-使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)
+使用模型权重时,请遵循对应的模型协议:[Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/main/Community%20License%20for%20Baichuan%202%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [DeepSeek](https://github.com/deepseek-ai/DeepSeek-LLM/blob/main/LICENSE-MODEL) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [Gemma](https://ai.google.dev/gemma/terms) / [InternLM2](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5/2](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/Tongyi%20Qianwen%20LICENSE%20AGREEMENT) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf) / [Yi](https://huggingface.co/01-ai/Yi-6B/blob/main/LICENSE) / [Yuan](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/LICENSE-Yuan)

 ## 引用
SECURITY.md (new file, +7)

@@ -0,0 +1,7 @@
+# Reporting Security Issues
+
+To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/electron/electron/security/advisories/new) tab.
+
+We will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance.
+
+Report security bugs in third-party modules to the person or team maintaining the module.
examples/full_multi_gpu/sft.sh (new file, +29)

@@ -0,0 +1,29 @@
+#!/bin/bash
+
+deepspeed --num_gpus 4 ../../src/train_bash.py \
+    --deepspeed ds_z3_config.json \
+    --stage sft \
+    --do_train \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --dataset alpaca_gpt4_en \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type full \
+    --output_dir ../../saves/LLaMA2-7B/full/sft \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 2 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --learning_rate 5e-5 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
+    --val_size 0.1 \
+    --plot_loss \
+    --fp16
examples/lora_multi_gpu/config.yaml (new file, +16)

@@ -0,0 +1,16 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
examples/lora_multi_gpu/sft.sh (new file, +30)

@@ -0,0 +1,30 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file config.yaml ../../src/train_bash.py \
+    --stage sft \
+    --do_train \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --dataset alpaca_gpt4_en,glaive_toolcall \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --output_dir ../../saves/LLaMA2-7B/lora/sft \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 2 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
+    --learning_rate 5e-5 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
+    --val_size 0.1 \
+    --plot_loss \
+    --fp16
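Taken together with `config.yaml` above (`num_processes: 4`), this recipe trains with an effective global batch size of 8. A quick sanity check of that arithmetic, with the values copied from the two files:

```python
# Effective global batch size = per-device batch × grad accumulation × processes.
per_device_train_batch_size = 1   # from sft.sh
gradient_accumulation_steps = 2   # from sft.sh
num_processes = 4                 # from config.yaml, one process per GPU

effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_processes
)
assert effective_batch_size == 8
```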
examples/lora_single_gpu/dpo.sh (new file, +33)

@@ -0,0 +1,33 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage dpo \
+    --do_train \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
+    --create_new_adapter \
+    --dataset comparison_gpt4_en \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --output_dir ../../saves/LLaMA2-7B/lora/dpo \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
+    --learning_rate 1e-5 \
+    --num_train_epochs 1.0 \
+    --max_samples 1000 \
+    --val_size 0.1 \
+    --dpo_ftx 1.0 \
+    --plot_loss \
+    --fp16
examples/lora_single_gpu/ppo.sh (new file, +31)

@@ -0,0 +1,31 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage ppo \
+    --do_train \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
+    --create_new_adapter \
+    --dataset alpaca_gpt4_en \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --reward_model ../../saves/LLaMA2-7B/lora/reward \
+    --output_dir ../../saves/LLaMA2-7B/lora/ppo \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 512 \
+    --per_device_train_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 100 \
+    --learning_rate 1e-5 \
+    --num_train_epochs 1.0 \
+    --max_samples 1000 \
+    --top_k 0 \
+    --top_p 0.9 \
+    --max_new_tokens 256 \
+    --plot_loss \
+    --fp16
examples/lora_single_gpu/predict.sh (new file, +18)

@@ -0,0 +1,18 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage sft \
+    --do_predict \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft,../../saves/LLaMA2-7B/lora/dpo \
+    --dataset alpaca_gpt4_en,glaive_toolcall \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type lora \
+    --output_dir ../../saves/LLaMA2-7B/lora/predict \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --per_device_eval_batch_size 1 \
+    --max_samples 20 \
+    --predict_with_generate
examples/lora_single_gpu/pretrain.sh (new file, +29)

@@ -0,0 +1,29 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage pt \
+    --do_train \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --dataset c4_demo \
+    --dataset_dir ../../data \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --output_dir ../../saves/LLaMA2-7B/lora/pretrain \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
+    --learning_rate 5e-5 \
+    --num_train_epochs 3.0 \
+    --max_samples 10000 \
+    --val_size 0.1 \
+    --plot_loss \
+    --fp16
examples/lora_single_gpu/reward.sh (new file, +31)

@@ -0,0 +1,31 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage rm \
+    --do_train \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --adapter_name_or_path ../../saves/LLaMA2-7B/lora/sft \
+    --create_new_adapter \
+    --dataset comparison_gpt4_en \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --output_dir ../../saves/LLaMA2-7B/lora/reward \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --learning_rate 1e-5 \
+    --num_train_epochs 1.0 \
+    --max_samples 5000 \
+    --val_size 0.1 \
+    --plot_loss \
+    --fp16
examples/lora_single_gpu/sft.sh (new file, +30)

@@ -0,0 +1,30 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage sft \
+    --do_train \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --dataset alpaca_gpt4_en,glaive_toolcall \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --output_dir ../../saves/LLaMA2-7B/lora/sft \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
+    --learning_rate 5e-5 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
+    --val_size 0.1 \
+    --plot_loss \
+    --fp16
examples/qlora_single_gpu/aqlm.sh (new file, +30)

@@ -0,0 +1,30 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage sft \
+    --do_train \
+    --model_name_or_path BlackSamorez/Llama-2-7b-AQLM-2Bit-1x16-hf \
+    --dataset alpaca_gpt4_en,glaive_toolcall \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --output_dir ../../saves/LLaMA2-7B/lora/sft \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
+    --learning_rate 5e-5 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
+    --val_size 0.1 \
+    --plot_loss \
+    --fp16
examples/qlora_single_gpu/awq.sh (new file, +30)

@@ -0,0 +1,30 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage sft \
+    --do_train \
+    --model_name_or_path TheBloke/Llama-2-7B-AWQ \
+    --dataset alpaca_gpt4_en,glaive_toolcall \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --output_dir ../../saves/LLaMA2-7B/lora/sft \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
+    --learning_rate 5e-5 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
+    --val_size 0.1 \
+    --plot_loss \
+    --fp16
examples/qlora_single_gpu/bitsandbytes.sh (new file, +31)

@@ -0,0 +1,31 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage sft \
+    --do_train \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --dataset alpaca_gpt4_en,glaive_toolcall \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --output_dir ../../saves/LLaMA2-7B/lora/sft \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
+    --learning_rate 5e-5 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
+    --val_size 0.1 \
+    --quantization_bit 4 \
+    --plot_loss \
+    --fp16
examples/qlora_single_gpu/gptq.sh (new file, +30)

@@ -0,0 +1,30 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage sft \
+    --do_train \
+    --model_name_or_path TheBloke/Llama-2-7B-GPTQ \
+    --dataset alpaca_gpt4_en,glaive_toolcall \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --output_dir ../../saves/LLaMA2-7B/lora/sft \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
+    --learning_rate 5e-5 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
+    --val_size 0.1 \
+    --plot_loss \
+    --fp16
pyproject.toml

@@ -2,11 +2,8 @@
 requires = ["setuptools>=61.0"]
 build-backend = "setuptools.build_meta"

-[tool.black]
-line-length = 119
-target-version = ["py38"]
-
 [tool.ruff]
 target-version = "py38"
 line-length = 119
+indent-width = 4

@@ -17,17 +14,7 @@ select = ["C", "E", "F", "I", "W"]
 [tool.ruff.lint.isort]
 lines-after-imports = 2
 known-first-party = ["llmtuner"]
-
-[tool.ruff.format]
-quote-style = "double"
-indent-style = "space"
-skip-magic-trailing-comma = false
-line-ending = "auto"
-
-[isort]
-default_section = "FIRSTPARTY"
-known_first_party = "llmtuner"
-known_third_party = [
+known-third-party = [
     "accelerate",
     "datasets",
     "gradio",

@@ -37,10 +24,9 @@ known_third_party = [
     "transformers",
     "trl"
 ]
-line_length = 119
-lines_after_imports = 2
-multi_line_output = 3
-include_trailing_comma = true
-force_grid_wrap = 0
-use_parentheses = true
-ensure_newline_before_comments = true
+
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
requirements.txt

@@ -1,9 +1,9 @@
 torch>=1.13.1
 transformers>=4.37.2
 datasets>=2.14.3
-accelerate>=0.21.0
-peft>=0.8.2
-trl>=0.7.6
+accelerate>=0.27.2
+peft>=0.9.0
+trl>=0.7.11
 gradio>=3.38.0,<4.0.0
 scipy
 einops
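The raised floors for `accelerate`, `peft` and `trl` match the minimums in the new README requirement table. A hedged helper (not part of the repo) to check an environment against them, using only the standard library:

```python
from importlib.metadata import PackageNotFoundError, version

# Minimum versions taken from the updated requirements.txt above.
minimums = {"accelerate": "0.27.2", "peft": "0.9.0", "trl": "0.7.11"}

for package, minimum in minimums.items():
    try:
        print(f"{package}: installed {version(package)}, required >= {minimum}")
    except PackageNotFoundError:
        print(f"{package}: NOT installed, required >= {minimum}")
```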
src/llmtuner/__init__.py

@@ -7,5 +7,5 @@ from .train import export_model, run_exp
 from .webui import create_ui, create_web_demo


-__version__ = "0.5.2"
+__version__ = "0.5.3"
 __all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
src/llmtuner/api/app.py

@@ -75,11 +75,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
     semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
     role_mapping = {
-        Role.USER: DataRole.USER,
-        Role.ASSISTANT: DataRole.ASSISTANT,
-        Role.SYSTEM: DataRole.SYSTEM,
-        Role.FUNCTION: DataRole.FUNCTION,
-        Role.TOOL: DataRole.OBSERVATION,
+        Role.USER: DataRole.USER.value,
+        Role.ASSISTANT: DataRole.ASSISTANT.value,
+        Role.SYSTEM: DataRole.SYSTEM.value,
+        Role.FUNCTION: DataRole.FUNCTION.value,
+        Role.TOOL: DataRole.OBSERVATION.value,
     }

     @app.get("/v1/models", response_model=ModelList)

@@ -95,7 +95,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if len(request.messages) == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")

-        if role_mapping[request.messages[0].role] == DataRole.SYSTEM:
+        if request.messages[0].role == Role.SYSTEM:
             system = request.messages.pop(0).content
         else:
             system = ""

@@ -105,11 +105,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         input_messages = []
         for i, message in enumerate(request.messages):
-            if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
-                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
-            elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
-                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
-
             input_messages.append({"role": role_mapping[message.role], "content": message.content})
+            if i % 2 == 0 and input_messages[i]["role"] not in [DataRole.USER, DataRole.OBSERVATION]:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+            elif i % 2 == 1 and input_messages[i]["role"] not in [DataRole.ASSISTANT, DataRole.FUNCTION]:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

         tool_list = request.tools
         if isinstance(tool_list, list) and len(tool_list):
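The recurring `Role.X` → `Role.X.value` rewrite in this changeset swaps enum members for their plain string values wherever message dicts are built. A minimal sketch of the distinction, assuming a str-backed enum similar to (but not necessarily identical to) the repo's `Role`/`DataRole` definitions:

```python
from enum import Enum


class Role(str, Enum):  # toy stand-in for llmtuner's role enums
    USER = "user"
    ASSISTANT = "assistant"


assert Role.USER == "user"           # str-backed members compare equal to their value
assert isinstance(Role.USER, Enum)   # but they are still Enum instances
assert type(Role.USER.value) is str  # while .value is a plain string

# Storing .value keeps rows as plain strings, which serialize predictably
# through the JSON/Arrow machinery the API and `datasets` rely on:
message = {"role": Role.USER.value, "content": "hi"}
```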
src/llmtuner/data/aligner.py

@@ -19,8 +19,8 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
         prompt = []
         if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
             for old_prompt, old_response in examples[dataset_attr.history][i]:
-                prompt.append({"role": Role.USER, "content": old_prompt})
-                prompt.append({"role": Role.ASSISTANT, "content": old_response})
+                prompt.append({"role": Role.USER.value, "content": old_prompt})
+                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

         content = []
         if dataset_attr.prompt and examples[dataset_attr.prompt][i]:

@@ -29,12 +29,14 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
         if dataset_attr.query and examples[dataset_attr.query][i]:
             content.append(examples[dataset_attr.query][i])

-        prompt.append({"role": Role.USER, "content": "\n".join(content)})
+        prompt.append({"role": Role.USER.value, "content": "\n".join(content)})

         if dataset_attr.response and isinstance(examples[dataset_attr.response][i], list):
-            response = [{"role": Role.ASSISTANT, "content": content} for content in examples[dataset_attr.response][i]]
+            response = [
+                {"role": Role.ASSISTANT.value, "content": content} for content in examples[dataset_attr.response][i]
+            ]
         elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):
-            response = [{"role": Role.ASSISTANT, "content": examples[dataset_attr.response][i]}]
+            response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
         else:
             response = []

@@ -49,11 +51,11 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
 def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]:
     outputs = {"prompt": [], "response": [], "system": [], "tools": []}
     tag_mapping = {
-        dataset_attr.user_tag: Role.USER,
-        dataset_attr.assistant_tag: Role.ASSISTANT,
-        dataset_attr.observation_tag: Role.OBSERVATION,
-        dataset_attr.function_tag: Role.FUNCTION,
-        dataset_attr.system_tag: Role.SYSTEM,
+        dataset_attr.user_tag: Role.USER.value,
+        dataset_attr.assistant_tag: Role.ASSISTANT.value,
+        dataset_attr.observation_tag: Role.OBSERVATION.value,
+        dataset_attr.function_tag: Role.FUNCTION.value,
+        dataset_attr.system_tag: Role.SYSTEM.value,
     }
     odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
     even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
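For orientation, here is a hedged toy walk-through of the alpaca-style alignment performed above. The field names follow the common alpaca convention and are illustrative only; the real function resolves them through `dataset_attr`:

```python
example = {
    "instruction": "Translate to French",
    "input": "Good morning",
    "output": "Bonjour",
    "history": [["Hi", "Hello!"]],
}

# History turns become alternating user/assistant messages in `prompt`...
prompt = []
for old_prompt, old_response in example["history"]:
    prompt.append({"role": "user", "content": old_prompt})
    prompt.append({"role": "assistant", "content": old_response})

# ...the instruction and input are joined into the final user message...
prompt.append({"role": "user", "content": "\n".join([example["instruction"], example["input"]])})

# ...and the answer becomes the response, all with plain-string roles.
response = [{"role": "assistant", "content": example["output"]}]
```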
src/llmtuner/data/formatter.py

@@ -75,7 +75,8 @@ class Formatter(ABC):
     tool_format: Literal["default"] = "default"

     @abstractmethod
-    def apply(self, **kwargs) -> SLOTS: ...
+    def apply(self, **kwargs) -> SLOTS:
+        ...

     def extract(self, content: str) -> Union[str, Tuple[str, str]]:
         raise NotImplementedError
@@ -99,12 +99,12 @@ def preprocess_packed_supervised_dataset(
             continue

         messages = examples["prompt"][i] + examples["response"][i]
-        for turn_idx, (source_ids, target_ids) in enumerate(
-            template.encode_multiturn(tokenizer, messages, examples["system"][i], examples["tools"][i])
+        for source_ids, target_ids in template.encode_multiturn(
+            tokenizer, messages, examples["system"][i], examples["tools"][i]
         ):
             if data_args.train_on_prompt:
                 source_mask = source_ids
-            elif turn_idx != 0 and template.efficient_eos:
+            elif len(input_ids) != 0 and template.efficient_eos:
                 source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
             else:
                 source_mask = [IGNORE_INDEX] * len(source_ids)
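Why the guard moved from `turn_idx != 0` to `len(input_ids) != 0`: in the packed variant, `input_ids` accumulates tokens across examples, so "is this the first segment" must be asked of the packed stream, not of the current example's turn counter. A toy view (token values are illustrative, not from the codebase):

input_ids = []  # the packed stream built so far, across examples
for source_ids, target_ids in [([1, 2], [3]), ([4], [5, 6])]:
    first_in_stream = len(input_ids) == 0
    print(first_in_stream)  # True for the very first segment only, then False
    input_ids += source_ids + target_ids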
@@ -122,9 +122,10 @@ def preprocess_packed_supervised_dataset(
     total_length = (total_length // block_size) * block_size
     # split by chunks of cutoff_len
     for i in range(0, total_length, block_size):
-        model_inputs["input_ids"].append(input_ids[i : i + block_size])
-        model_inputs["attention_mask"].append([1] * block_size)
-        model_inputs["labels"].append(labels[i : i + block_size])
+        if not all(label == IGNORE_INDEX for label in labels[i : i + block_size]):
+            model_inputs["input_ids"].append(input_ids[i : i + block_size])
+            model_inputs["attention_mask"].append([1] * block_size)
+            model_inputs["labels"].append(labels[i : i + block_size])

     return model_inputs
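The new filter drops packed blocks whose labels are entirely IGNORE_INDEX: such a block contributes no supervised tokens, so keeping it wastes a training step (and an all-masked batch can produce a degenerate loss). A toy illustration with made-up values:

IGNORE_INDEX = -100
block_size = 4
labels = [-100, -100, -100, -100, 7, 8, -100, 9]

kept = [
    labels[i : i + block_size]
    for i in range(0, len(labels), block_size)
    if not all(label == IGNORE_INDEX for label in labels[i : i + block_size])
]
print(kept)  # [[7, 8, -100, 9]] -- the fully masked first block is skipped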
@@ -145,7 +146,7 @@ def preprocess_unsupervised_dataset(
         if len(examples["response"][i]) == 1:
             messages = examples["prompt"][i] + examples["response"][i]
         else:
-            messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]
+            messages = examples["prompt"][i] + [{"role": Role.ASSISTANT.value, "content": ""}]

         input_ids, labels = template.encode_oneturn(
             tokenizer,
@@ -180,7 +181,6 @@ def preprocess_pairwise_dataset(

         chosen_messages = examples["prompt"][i] + [examples["response"][i][0]]
         rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]

         prompt_ids, chosen_ids = template.encode_oneturn(
             tokenizer,
             chosen_messages,
@@ -88,16 +88,16 @@ class Template:
             elif i > 0 and i % 2 == 0:
                 elements += self.format_separator.apply()

-            if message["role"] == Role.USER:
+            if message["role"] == Role.USER.value:
                 elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
-            elif message["role"] == Role.ASSISTANT:
+            elif message["role"] == Role.ASSISTANT.value:
                 elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION:
+            elif message["role"] == Role.OBSERVATION.value:
                 elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION:
+            elif message["role"] == Role.FUNCTION.value:
                 elements += self.format_function.apply(content=message["content"])
             else:
-                raise NotImplementedError
+                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
@@ -179,16 +179,16 @@ class Llama2Template(Template):
             elif i > 0 and i % 2 == 0:
                 elements += self.format_separator.apply()

-            if message["role"] == Role.USER:
+            if message["role"] == Role.USER.value:
                 elements += self.format_user.apply(content=system_text + message["content"])
-            elif message["role"] == Role.ASSISTANT:
+            elif message["role"] == Role.ASSISTANT.value:
                 elements += self.format_assistant.apply(content=message["content"])
-            elif message["role"] == Role.OBSERVATION:
+            elif message["role"] == Role.OBSERVATION.value:
                 elements += self.format_observation.apply(content=message["content"])
-            elif message["role"] == Role.FUNCTION:
+            elif message["role"] == Role.FUNCTION.value:
                 elements += self.format_function.apply(content=message["content"])
             else:
-                raise NotImplementedError
+                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

             encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
@@ -258,7 +258,7 @@ def get_template_and_fix_tokenizer(
         template = templates["vanilla"]  # placeholder
     else:
         template = templates.get(name, None)
-        if templates is None:
+        if template is None:
             raise ValueError("Template {} does not exist.".format(name))

     stop_words = template.stop_words
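The one-character fix above matters because `dict.get` returns None for a missing key, while the old code tested the dictionary itself, which is never None, so unknown template names slipped past the check:

templates = {"vanilla": object()}
template = templates.get("no_such_template", None)
assert templates is not None  # old check: always true, error never raised
assert template is None       # new check: correctly detects the missing name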
@@ -308,6 +308,15 @@ _register_template(
 )


+_register_template(
+    name="atom",
+    format_user=StringFormatter(
+        slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
+)
+
+
 _register_template(
     name="baichuan",
     format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
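For orientation, a rough sketch of what the newly registered "atom" template renders for one exchange, with <s> and </s> standing in for the model's actual bos/eos tokens (an approximation of the slot expansion, not verified LLaMA Factory output):

bos, eos = "<s>", "</s>"
user, answer = "Hello", "Hi there!"
prompt = f"{bos}Human: {user}\n{eos}{bos}Assistant:"
print(repr(prompt + f"{answer}\n{eos}"))
# '<s>Human: Hello\n</s><s>Assistant:Hi there!\n</s>'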
@@ -351,6 +360,21 @@ _register_template(
     name="chatglm3",
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
     format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
     format_function=FunctionFormatter(slots=["{{name}}\n{{arguments}}"]),
     format_observation=StringFormatter(
         slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
     ),
     stop_words=["<|user|>", "<|observation|>"],
     efficient_eos=True,
     force_system=True,
 )


+_register_template(
+    name="chatglm3_system",
+    format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
+    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
+    format_system=StringFormatter(
+        slots=[{"token": "[gMASK]"}, {"token": "sop"}, {"token": "<|system|>"}, "\n", "{{content}}"]
+    ),
@@ -367,13 +391,23 @@ _register_template(
 )


 _register_template(
     name="chatml",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_separator=EmptyFormatter(slots=["\n"]),
     stop_words=["<|im_end|>", "<|im_start|>"],
     replace_eos=True,
 )


 _register_template(
     name="chatml_de",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_separator=EmptyFormatter(slots=["\n"]),
     default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
-    stop_words=["<|im_end|>"],
+    stop_words=["<|im_end|>", "<|im_start|>"],
     replace_eos=True,
 )
@@ -433,6 +467,16 @@ _register_template(
 )


+_register_template(
+    name="gemma",
+    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+    format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
+    efficient_eos=True,
+    force_system=True,
+)
+
+
 _register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
@@ -495,7 +539,7 @@ _register_template(
 _register_template(
     name="openchat",
     format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
-    format_assistant=StringFormatter(slots=["{{content}}"]),
+    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}]),
     format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
     force_system=True,
 )
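Likewise, an approximate single-turn rendering of the "gemma" template added above, with <bos> standing in for the tokenizer's bos token (with efficient_eos, the <end_of_turn> terminator comes from the separator slot; this is a sketch, not verified output):

bos, user, answer = "<bos>", "Why is the sky blue?", "Rayleigh scattering."
prompt = f"{bos}<start_of_turn>user\n{user}<end_of_turn>\n<start_of_turn>model\n"
print(repr(prompt + f"{answer}<end_of_turn>\n"))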
@@ -324,6 +324,29 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Gemma-2B": {
+            DownloadSource.DEFAULT: "google/gemma-2b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b",
+        },
+        "Gemma-7B": {
+            DownloadSource.DEFAULT: "google/gemma-7b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b",
+        },
+        "Gemma-2B-Chat": {
+            DownloadSource.DEFAULT: "google/gemma-2b-it",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it",
+        },
+        "Gemma-7B-Chat": {
+            DownloadSource.DEFAULT: "google/gemma-7b-it",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
+        },
+    },
+    template="gemma",
+)
+
+
 register_model_group(
     models={
         "InternLM-7B": {
@@ -543,7 +566,10 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
             DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
         },
-        "Qwen-7B-Chat": {DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"},
+        "Qwen-7B-Chat": {
+            DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
+            DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat",
+        },
         "Qwen-14B-Chat": {
             DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
             DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
@@ -823,10 +849,18 @@ register_model_group(
             DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
             DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
         },
+        "Yi-6B-int4-Chat": {
+            DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
+            DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
+        },
         "Yi-34B-int8-Chat": {
             DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
             DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
         },
+        "Yi-34B-int4-Chat": {
+            DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
+            DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
+        },
     },
     template="yi",
 )
@@ -864,3 +898,18 @@ register_model_group(
     },
     template="zephyr",
 )
+
+
+register_model_group(
+    models={
+        "Atom-7B": {
+            DownloadSource.DEFAULT: "FlagAlpha/Atom-7B",
+            DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B",
+        },
+        "Atom-7B-Chat": {
+            DownloadSource.DEFAULT: "FlagAlpha/Atom-7B-Chat",
+            DownloadSource.MODELSCOPE: "FlagAlpha/Atom-7B-Chat",
+        },
+    },
+    template="atom",
+)
@@ -26,10 +26,6 @@ class FreezeArguments:
         default=3,
         metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
     )
-    use_llama_pro: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether or not to use llama pro for partial-parameter (freeze) fine-tuning."},
-    )


 @dataclass
@@ -78,6 +74,9 @@ class LoraArguments:
         default=False,
         metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
     )
+    use_dora: Optional[bool] = field(
+        default=False, metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}
+    )
     create_new_adapter: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
@@ -94,7 +93,7 @@ class RLHFArguments:
         default=0.1,
         metadata={"help": "The beta parameter for the DPO loss."},
     )
-    dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
+    dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto_pair"]] = field(
         default="sigmoid",
         metadata={"help": "The type of DPO loss to use."},
     )
@@ -170,6 +169,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         default="lora",
         metadata={"help": "Which fine-tuning method to use."},
     )
+    use_llama_pro: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
+    )
     disable_version_checking: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether or not to disable version checking."},
@@ -195,13 +198,13 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
         assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

         if self.stage == "ppo" and self.reward_model is None:
-            raise ValueError("Reward model is necessary for PPO training.")
+            raise ValueError("`reward_model` is necessary for PPO training.")

         if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
-            raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
+            raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")

-        if self.use_llama_pro and self.finetuning_type != "freeze":
-            raise ValueError("`use_llama_pro` is only valid for the Freeze method.")
+        if self.use_llama_pro and self.finetuning_type == "full":
+            raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA method.")

     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""
@@ -3,7 +3,6 @@ import os
 import sys
 from typing import Any, Dict, Optional, Tuple

 import datasets
-import torch
 import transformers
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
@@ -36,9 +35,9 @@ def _check_dependencies(disabled: bool) -> None:
     else:
         require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
         require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
-        require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
-        require_version("peft>=0.8.2", "To fix: pip install peft>=0.8.2")
-        require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
+        require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
+        require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
+        require_version("trl>=0.7.11", "To fix: pip install trl>=0.7.11")
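For context, `require_version` is the helper from transformers' utilities; it raises with the supplied hint when the installed package is below the floor, so the bumps above translate directly into install commands:

from transformers.utils.versions import require_version

# raises with the "To fix: ..." hint if the installed peft is older than 0.9.0
require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")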
 def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
@@ -62,7 +61,6 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non


 def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
     datasets.utils.logging.set_verbosity(log_level)
     transformers.utils.logging.set_verbosity(log_level)
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()
@@ -144,7 +142,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         raise ValueError("Please specify `lora_target` in LoRA training.")

     if training_args.do_train and model_args.use_unsloth and not is_unsloth_available:
-        raise ValueError("Install Unsloth: https://github.com/unslothai/unsloth")
+        raise ValueError("Unsloth was not installed: https://github.com/unslothai/unsloth")
+
+    if finetuning_args.use_dora:
+        if model_args.quantization_bit is not None:
+            raise ValueError("DoRA does not support quantization.")
+
+        if model_args.use_unsloth:
+            raise ValueError("Unsloth does not support DoRA.")

     _verify_model_args(model_args, finetuning_args)
     _check_dependencies(disabled=finetuning_args.disable_version_checking)
@@ -236,7 +241,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
             str(model_args.compute_dtype),
         )
     )
     logger.info(f"Training/evaluation parameters {training_args}")

     transformers.set_seed(training_args.seed)
@@ -5,7 +5,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled

 from ..extras.logging import get_logger
-from .utils import find_all_linear_modules
+from .utils import find_all_linear_modules, find_expanded_modules


 if TYPE_CHECKING:
@@ -82,8 +82,10 @@ def init_adapter(
             else:
                 param.requires_grad_(False)

+        logger.info("Set trainable layers: {}".format(",".join(map(str, trainable_layer_ids))))
+
     if finetuning_args.finetuning_type == "lora":
-        logger.info("Fine-tuning method: LoRA")
+        logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
         adapter_to_resume = None

         if model_args.adapter_name_or_path is not None:
@@ -118,6 +120,13 @@ def init_adapter(
         else:
             target_modules = finetuning_args.lora_target

+        if finetuning_args.use_llama_pro:
+            target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)
+
+        if finetuning_args.use_dora:
+            if getattr(model, "quantization_method", None):
+                raise ValueError("DoRA is currently not compatible with quantized models.")
+
         peft_kwargs = {
             "r": finetuning_args.lora_rank,
             "target_modules": target_modules,
@@ -136,6 +145,7 @@ def init_adapter(
             task_type=TaskType.CAUSAL_LM,
             inference_mode=False,
             modules_to_save=finetuning_args.additional_target,
+            use_dora=finetuning_args.use_dora,
             **peft_kwargs,
         )
         model = get_peft_model(model, lora_config)
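A minimal sketch of the DoRA path wired up above, outside the LLaMA Factory plumbing. It assumes peft>=0.9.0, which is where LoraConfig gained the use_dora flag; the model and target module names below are illustrative:

from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["c_attn"],  # GPT-2's fused attention projection
    use_dora=True,              # weight-decomposed LoRA (DoRA) instead of plain LoRA
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()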
@@ -157,7 +157,7 @@ def _configure_quantization(
     config_kwargs: Dict[str, Any],
 ) -> None:
     r"""
-    Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
+    Priority: PTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
     """
     if getattr(config, "quantization_config", None):  # gptq
         if is_deepspeed_zero3_enabled():
@@ -167,7 +167,15 @@ def _configure_quantization(
         quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
         if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
             quantization_config["use_exllama"] = False  # disable exllama
-        logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))
+
+        if quantization_config.get("quant_method", None) == "aqlm":
+            quantization_config["bits"] = 2
+
+        logger.info(
+            "Loading {}-bit {}-quantized model.".format(
+                quantization_config.get("bits", "?"), quantization_config.get("quant_method", None)
+            )
+        )

     elif model_args.export_quantization_bit is not None:  # auto-gptq
         require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
@@ -76,6 +76,33 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     return list(module_names)


+def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
+    r"""
+    Finds the modules in the expanded blocks to apply lora.
+    """
+    num_layers = getattr(model.config, "num_hidden_layers", None)
+    if not num_layers:
+        raise ValueError("Model was not supported.")
+
+    if num_layers % num_layer_trainable != 0:
+        raise ValueError(
+            "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable)
+        )
+
+    stride = num_layers // num_layer_trainable
+    trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
+    trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids]
+    module_names = []
+    for name, _ in model.named_modules():
+        if any(target_module in name for target_module in target_modules) and any(
+            trainable_layer in name for trainable_layer in trainable_layers
+        ):
+            module_names.append(name)
+
+    logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+    return module_names
+
+
 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
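A worked example of the layer-selection arithmetic in find_expanded_modules: for a 32-layer model with num_layer_trainable=8, LoRA lands on the last block of every stride-4 window (the numbers here are illustrative, not from the diff):

num_layers, num_layer_trainable = 32, 8
stride = num_layers // num_layer_trainable  # 4
trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
print(list(trainable_layer_ids))  # [3, 7, 11, 15, 19, 23, 27, 31]
print([".{:d}.".format(i) for i in trainable_layer_ids][:2])  # ['.3.', '.7.'], matched as substrings of module names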
@@ -18,7 +18,7 @@ class CustomDPOTrainer(DPOTrainer):
     def __init__(
         self,
         beta: float,
-        loss_type: Literal["sigmoid", "hinge", "ipo", "kto"],
+        loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"],
         ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
@@ -30,6 +30,7 @@ class CustomDPOTrainer(DPOTrainer):
         if ref_model is not None:
             disable_dropout_in_model(ref_model)

+        self.reference_free = False
         self.use_dpo_data_collator = True  # hack to avoid warning
         self.generate_during_eval = False  # disable at evaluation
         self.label_pad_token_id = IGNORE_INDEX
@@ -61,6 +61,7 @@ def run_ppo(
         use_score_norm=finetuning_args.ppo_score_norm,
         whiten_rewards=finetuning_args.ppo_whiten_rewards,
         accelerator_kwargs={"step_scheduler_with_optimizer": False},
+        project_kwargs={"logging_dir": training_args.logging_dir},
     )

     # Create optimizer and scheduler
@@ -106,6 +106,7 @@ class WebChatModel(ChatModel):
     def predict(
         self,
         chatbot: List[Tuple[str, str]],
+        role: str,
         query: str,
         messages: Sequence[Tuple[str, str]],
         system: str,
@@ -115,7 +116,7 @@ class WebChatModel(ChatModel):
         temperature: float,
     ) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
         chatbot.append([query, ""])
-        query_messages = messages + [{"role": Role.USER, "content": query}]
+        query_messages = messages + [{"role": role, "content": query}]
         response = ""
         for new_text in self.stream_chat(
             query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
@@ -130,10 +131,10 @@ class WebChatModel(ChatModel):
             name, arguments = result
             arguments = json.loads(arguments)
             tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
-            output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
+            output_messages = query_messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
             bot_text = "```json\n" + tool_call + "\n```"
         else:
-            output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
+            output_messages = query_messages + [{"role": Role.ASSISTANT.value, "content": result}]
             bot_text = result

         chatbot[-1] = [query, self.postprocess(bot_text)]
@@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple

 import gradio as gr

+from ...data import Role
 from ..utils import check_json_schema
@@ -20,23 +21,24 @@ def create_chat_box(
     messages = gr.State([])
     with gr.Row():
         with gr.Column(scale=4):
+            role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
             system = gr.Textbox(show_label=False)
             tools = gr.Textbox(show_label=False, lines=2)
             query = gr.Textbox(show_label=False, lines=8)
             submit_btn = gr.Button(variant="primary")

         with gr.Column(scale=1):
-            clear_btn = gr.Button()
             gen_kwargs = engine.chatter.generating_args
             max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1)
             top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
             temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
+            clear_btn = gr.Button()

     tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])

     submit_btn.click(
         engine.chatter.predict,
-        [chatbot, query, messages, system, tools, max_new_tokens, top_p, temperature],
+        [chatbot, role, query, messages, system, tools, max_new_tokens, top_p, temperature],
         [chatbot, messages],
         show_progress=True,
     ).then(lambda: gr.update(value=""), outputs=[query])
@@ -48,13 +50,14 @@ def create_chat_box(
         chatbot,
         messages,
         dict(
+            role=role,
             system=system,
             tools=tools,
             query=query,
             submit_btn=submit_btn,
-            clear_btn=clear_btn,
             max_new_tokens=max_new_tokens,
             top_p=top_p,
             temperature=temperature,
+            clear_btn=clear_btn,
         ),
     )
@@ -34,9 +34,7 @@ def create_top() -> Dict[str, "Component"]:

     model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
         get_model_path, [model_name], [model_path], queue=False
-    ).then(
-        get_template, [model_name], [template], queue=False
-    )  # do not save config since the below line will save
+    ).then(get_template, [model_name], [template], queue=False)  # do not save config since the below line will save

     model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
@@ -34,7 +34,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))

     with gr.Row():
-        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
+        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1)
         learning_rate = gr.Textbox(value="5e-5")
         num_train_epochs = gr.Textbox(value="3.0")
         max_samples = gr.Textbox(value="100000")
@@ -52,8 +52,8 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     )

     with gr.Row():
-        batch_size = gr.Slider(value=4, minimum=1, maximum=1024, step=1)
-        gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=1024, step=1)
+        batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1)
+        gradient_accumulation_steps = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
         lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
         max_grad_norm = gr.Textbox(value="1.0")
         val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
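Note that the retuned defaults keep the effective batch size unchanged while lowering peak memory, since the per-device batch and the accumulation steps move in opposite directions:

old_effective = 4 * 4  # batch_size * gradient_accumulation_steps
new_effective = 2 * 8
assert old_effective == new_effective == 16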
@@ -108,27 +108,45 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         )
     )

+    with gr.Accordion(label="Freeze config", open=False) as freeze_tab:
+        with gr.Row():
+            num_layer_trainable = gr.Slider(value=3, minimum=1, maximum=128, step=1, scale=2)
+            name_module_trainable = gr.Textbox(scale=3)
+
+    input_elems.update({num_layer_trainable, name_module_trainable})
+    elem_dict.update(
+        dict(
+            freeze_tab=freeze_tab, num_layer_trainable=num_layer_trainable, name_module_trainable=name_module_trainable
+        )
+    )
+
     with gr.Accordion(label="LoRA config", open=False) as lora_tab:
         with gr.Row():
-            lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
-            lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01)
-            lora_target = gr.Textbox()
-            additional_target = gr.Textbox()
+            lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
+            lora_alpha = gr.Slider(value=16, minimum=1, maximum=2048, step=0.1, scale=1)
+            lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
+            lora_target = gr.Textbox(scale=2)

-        with gr.Column():
-            use_rslora = gr.Checkbox()
-            create_new_adapter = gr.Checkbox()
+        with gr.Row():
+            use_rslora = gr.Checkbox(scale=1)
+            use_dora = gr.Checkbox(scale=1)
+            create_new_adapter = gr.Checkbox(scale=1)
+            additional_target = gr.Textbox(scale=2)

-    input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, use_rslora, create_new_adapter})
+    input_elems.update(
+        {lora_rank, lora_alpha, lora_dropout, lora_target, use_rslora, use_dora, create_new_adapter, additional_target}
+    )
     elem_dict.update(
         dict(
             lora_tab=lora_tab,
             lora_rank=lora_rank,
+            lora_alpha=lora_alpha,
             lora_dropout=lora_dropout,
             lora_target=lora_target,
-            additional_target=additional_target,
             use_rslora=use_rslora,
+            use_dora=use_dora,
             create_new_adapter=create_new_adapter,
+            additional_target=additional_target,
         )
     )
@@ -477,7 +477,7 @@ LOCALES = {
     },
     "zh": {
         "label": "序列打包",
-        "info": "在指令监督微调阶段将序列打包为相同长度的样本。",
+        "info": "在指令监督微调时将序列打包为等长样本。",
     },
 },
 "upcast_layernorm": {
@@ -508,6 +508,45 @@ LOCALES = {
         "info": "仅训练块扩展后的参数。",
     },
 },
"freeze_tab": {
|
||||
"en": {
|
||||
"label": "Freeze tuning configurations",
|
||||
},
|
||||
"ru": {
|
||||
"label": "конфигурации для настройки заморозки",
|
||||
},
|
||||
"zh": {
|
||||
"label": "部分参数微调设置",
|
||||
},
|
||||
},
|
||||
"num_layer_trainable": {
|
||||
"en": {
|
||||
"label": "Trainable layers",
|
||||
"info": "The number of trainable layers.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Обучаемые слои",
|
||||
"info": "Количество обучаемых слоев.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "可训练层数",
|
||||
"info": "可训练模型层的数量。",
|
||||
},
|
||||
},
|
||||
"name_module_trainable": {
|
||||
"en": {
|
||||
"label": "Trainable modules",
|
||||
"info": "The name of trainable modules. Use commas to separate multiple modules.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Обучаемые модули",
|
||||
"info": "Название обучаемых модулей. Используйте запятые для разделения нескольких модулей.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "可训练模块",
|
||||
"info": "可训练模块的名称。使用英文逗号分隔多个名称。",
|
||||
},
|
||||
},
|
||||
"lora_tab": {
|
||||
"en": {
|
||||
"label": "LoRA configurations",
|
||||
@@ -533,6 +572,20 @@ LOCALES = {
|
||||
"info": "LoRA 矩阵的秩。",
|
||||
},
|
||||
},
|
||||
"lora_alpha": {
|
||||
"en": {
|
||||
"label": "LoRA Alpha",
|
||||
"info": "Lora scaling coefficient.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "LoRA Alpha",
|
||||
"info": "Коэффициент масштабирования LoRA.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 缩放系数",
|
||||
"info": "LoRA 缩放系数大小。",
|
||||
},
|
||||
},
|
||||
"lora_dropout": {
|
||||
"en": {
|
||||
"label": "LoRA Dropout",
|
||||
@@ -561,6 +614,48 @@ LOCALES = {
|
||||
"info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。",
|
||||
},
|
||||
},
|
||||
"use_rslora": {
|
||||
"en": {
|
||||
"label": "Use rslora",
|
||||
"info": "Use the rank stabilization scaling factor for LoRA layer.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Использовать rslora",
|
||||
"info": "Использовать коэффициент масштабирования стабилизации ранга для слоя LoRA.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 rslora",
|
||||
"info": "对 LoRA 层使用秩稳定缩放方法。",
|
||||
},
|
||||
},
|
||||
"use_dora": {
|
||||
"en": {
|
||||
"label": "Use DoRA",
|
||||
"info": "Use weight-decomposed LoRA.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Используйте DoRA",
|
||||
"info": "Используйте LoRA с декомпозицией весов.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 DoRA",
|
||||
"info": "使用权重分解的 LoRA。",
|
||||
},
|
||||
},
|
||||
"create_new_adapter": {
|
||||
"en": {
|
||||
"label": "Create new adapter",
|
||||
"info": "Create a new adapter with randomly initialized weight upon the existing one.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Создать новый адаптер",
|
||||
"info": "Создать новый адаптер с случайной инициализацией веса на основе существующего.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "新建适配器",
|
||||
"info": "在现有的适配器上创建一个随机初始化后的新适配器。",
|
||||
},
|
||||
},
|
||||
"additional_target": {
|
||||
"en": {
|
||||
"label": "Additional modules (optional)",
|
||||
@@ -578,34 +673,6 @@ LOCALES = {
|
||||
"info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。",
|
||||
},
|
||||
},
|
||||
"use_rslora": {
|
||||
"en": {
|
||||
"label": "Use rslora",
|
||||
"info": "Use the rank stabilization scaling factor for LoRA layer.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Использовать rslora",
|
||||
"info": "Использовать коэффициент масштабирования стабилизации ранга для слоя LoRA.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 rslora",
|
||||
"info": "对 LoRA 层使用秩稳定缩放方法。",
|
||||
},
|
||||
},
|
||||
"create_new_adapter": {
|
||||
"en": {
|
||||
"label": "Create new adapter",
|
||||
"info": "Create a new adapter with randomly initialized weight upon the existing one.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Создать новый адаптер",
|
||||
"info": "Создать новый адаптер с случайной инициализацией веса на основе существующего.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "新建适配器",
|
||||
"info": "在现有的适配器上创建一个随机初始化后的新适配器。",
|
||||
},
|
||||
},
|
||||
"rlhf_tab": {
|
||||
"en": {
|
||||
"label": "RLHF configurations",
|
||||
@@ -772,6 +839,17 @@ LOCALES = {
             "value": "模型未加载,请先加载模型。",
         },
     },
+    "role": {
+        "en": {
+            "label": "Role",
+        },
+        "ru": {
+            "label": "Роль",
+        },
+        "zh": {
+            "label": "角色",
+        },
+    },
     "system": {
         "en": {
             "placeholder": "System prompt (optional)",
@@ -816,17 +894,6 @@ LOCALES = {
             "value": "提交",
         },
     },
-    "clear_btn": {
-        "en": {
-            "value": "Clear history",
-        },
-        "ru": {
-            "value": "Очистить историю",
-        },
-        "zh": {
-            "value": "清空历史",
-        },
-    },
     "max_length": {
         "en": {
             "label": "Maximum length",
@@ -871,6 +938,17 @@ LOCALES = {
             "label": "温度系数",
         },
     },
+    "clear_btn": {
+        "en": {
+            "value": "Clear history",
+        },
+        "ru": {
+            "value": "Очистить историю",
+        },
+        "zh": {
+            "value": "清空历史",
+        },
+    },
     "max_shard_size": {
         "en": {
             "label": "Max shard size (GB)",
@@ -1016,6 +1094,11 @@ ALERTS = {
         "ru": "Неверная схема JSON.",
         "zh": "Json 格式错误。",
     },
+    "warn_no_cuda": {
+        "en": "CUDA environment was not detected.",
+        "ru": "Среда CUDA не обнаружена.",
+        "zh": "未检测到 CUDA 环境。",
+    },
     "info_aborting": {
         "en": "Aborted, wait for terminating...",
         "ru": "Прервано, ожидание завершения...",
@@ -8,6 +8,7 @@ import gradio as gr
 import transformers
 from gradio.components import Component  # cannot use TYPE_CHECKING here
 from transformers.trainer import TRAINING_ARGS_NAME
+from transformers.utils import is_torch_cuda_available

 from ..extras.callbacks import LogCallback
 from ..extras.constants import TRAINING_STAGES
@@ -64,12 +65,15 @@ class Runner:
         if len(dataset) == 0:
             return ALERTS["err_no_dataset"][lang]

-        if self.demo_mode and (not from_preview):
+        if not from_preview and self.demo_mode:
             return ALERTS["err_demo"][lang]

         if not from_preview and get_device_count() > 1:
             return ALERTS["err_device_count"][lang]

+        if not from_preview and not is_torch_cuda_available():
+            gr.Warning(ALERTS["warn_no_cuda"][lang])
+
         self.aborted = False
         self.logger_handler.reset()
         self.trainer_callback = LogCallback(self)
@@ -129,26 +133,36 @@ class Runner:
             sft_packing=get("train.sft_packing"),
             upcast_layernorm=get("train.upcast_layernorm"),
-            lora_rank=get("train.lora_rank"),
-            lora_dropout=get("train.lora_dropout"),
-            lora_target=get("train.lora_target") or get_module(get("top.model_name")),
-            additional_target=get("train.additional_target") or None,
-            use_rslora=get("train.use_rslora"),
-            create_new_adapter=get("train.create_new_adapter"),
+            use_llama_pro=get("train.use_llama_pro"),
             output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
             bf16=(get("train.compute_type") == "bf16"),
         )
         args["disable_tqdm"] = True

-        if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
-            args["create_new_adapter"] = args["quantization_bit"] is None
+        if args["finetuning_type"] == "freeze":
+            args["num_layer_trainable"] = int(get("train.num_layer_trainable"))
+            args["name_module_trainable"] = get("train.name_module_trainable")
+        elif args["finetuning_type"] == "lora":
+            args["lora_rank"] = int(get("train.lora_rank"))
+            args["lora_alpha"] = float(get("train.lora_alpha"))
+            args["lora_dropout"] = float(get("train.lora_dropout"))
+            args["lora_target"] = get("train.lora_target") or get_module(get("top.model_name"))
+            args["use_rslora"] = get("train.use_rslora")
+            args["use_dora"] = get("train.use_dora")
+            args["additional_target"] = get("train.additional_target") or None
+            if args["stage"] in ["rm", "ppo", "dpo"]:
+                args["create_new_adapter"] = args["quantization_bit"] is None
+            else:
+                args["create_new_adapter"] = get("train.create_new_adapter")
+
+        if args["use_llama_pro"]:
+            args["num_layer_trainable"] = int(get("train.num_layer_trainable"))

         if args["stage"] == "ppo":
             args["reward_model"] = get_save_dir(
                 get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
             )
-            args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"
+            args["reward_model_type"] = "lora" if args["finetuning_type"] == "lora" else "full"

         if args["stage"] == "dpo":
             args["dpo_beta"] = get("train.dpo_beta")
@@ -157,8 +171,9 @@ class Runner:
         if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
             args["val_size"] = get("train.val_size")
             args["evaluation_strategy"] = "steps"
-            args["eval_steps"] = get("train.save_steps")
-            args["load_best_model_at_end"] = True
+            args["eval_steps"] = args["save_steps"]
+            args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
+            args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]

         return args
@@ -44,11 +44,14 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
 def check_json_schema(text: str, lang: str) -> None:
     try:
         tools = json.loads(text)
-        for tool in tools:
-            assert "name" in tool
-    except AssertionError:
+        if tools:
+            assert isinstance(tools, list)
+            for tool in tools:
+                if "name" not in tool:
+                    raise ValueError("Name not found.")
+    except ValueError:
         gr.Warning(ALERTS["err_tool_name"][lang])
-    except json.JSONDecodeError:
+    except Exception:
         gr.Warning(ALERTS["err_json_schema"][lang])
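An example of a tools string that satisfies the stricter validation above: a JSON list in which every element carries a "name" key (the schema content itself is illustrative):

import json

text = '[{"name": "get_weather", "parameters": {"city": "string"}}]'
tools = json.loads(text)
assert isinstance(tools, list) and all("name" in tool for tool in tools)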