Compare commits

96 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 6eed1db36c |  |
|  | 948124f55e |  |
|  | 2b191ca776 |  |
|  | be4d2822ea |  |
|  | 736ddd0319 |  |
|  | dfa289aa72 |  |
|  | c2644f939a |  |
|  | f11c1ae562 |  |
|  | 3126164aa6 |  |
|  | ed10486cad |  |
|  | 04fa430c6c |  |
|  | fa1893b59c |  |
|  | e993e717a5 |  |
|  | c80e56423a |  |
|  | ffa09a01d6 |  |
|  | 7d04f8567b |  |
|  | baa709674f |  |
|  | ca9a494d0c |  |
|  | 37eb8c05cc |  |
|  | 7c046edb7b |  |
|  | 22cea38b20 |  |
|  | ef2ca0a827 |  |
|  | 7f0b908de2 |  |
|  | 5fc5e776ff |  |
|  | 93b281c016 |  |
|  | 9585699918 |  |
|  | bceaba551d |  |
|  | 0bfeed3a7e |  |
|  | 70a780c3c0 |  |
|  | d74ab5306c |  |
|  | 688e8601ab |  |
|  | 4933ab5956 |  |
|  | 6c7225a5d4 |  |
|  | a22982f2fa |  |
|  | c95479dddb |  |
|  | fc48bd8da0 |  |
|  | d5323bfa3f |  |
|  | e9d4a2b507 |  |
|  | 37bcbe8046 |  |
|  | fdfb644f0a |  |
|  | cde9f3db57 |  |
|  | 8bf5a98815 |  |
|  | be566a15a5 |  |
|  | d5f1b99ac4 |  |
|  | 2144bb0e27 |  |
|  | bc665bacc7 |  |
|  | 52bfcf4883 |  |
|  | 06df3d6fb6 |  |
|  | ca719a8697 |  |
|  | 72dfd74005 |  |
|  | 69302c4420 |  |
|  | 42d7019b2e |  |
|  | 5f0d0d6b9b |  |
|  | 76cb63e4f6 |  |
|  | 467d571206 |  |
|  | 972bfa700a |  |
|  | 458955d0fb |  |
|  | 990eeccf45 |  |
|  | a3a7465f00 |  |
|  | 031a819257 |  |
|  | eb4b4e3c8c |  |
|  | d2e1fe9b1d |  |
|  | 6e27a9e39a |  |
|  | 805478c911 |  |
|  | a281cdeb89 |  |
|  | cda698a67f |  |
|  | 15acd17716 |  |
|  | 34a2bddfcd |  |
|  | 370f817549 |  |
|  | 041390c37e |  |
|  | d9fe4bf500 |  |
|  | e0c7e944fc |  |
|  | 0845fe67db |  |
|  | fe3b12d900 |  |
|  | a70d56864e |  |
|  | fdbb2c5378 |  |
|  | 3c0aaf42af |  |
|  | 438e19160a |  |
|  | f2b2ff6950 |  |
|  | 86cef96305 |  |
|  | 5f50944baf |  |
|  | 0804fd2353 |  |
|  | 86419eb457 |  |
|  | 76f3ae7bf3 |  |
|  | aaa85190eb |  |
|  | e2a4e926b9 |  |
|  | d6e922dc1c |  |
|  | 27f4317ec6 |  |
|  | e434348216 |  |
|  | 2e19afedb8 |  |
|  | da08fa7c63 |  |
|  | 9c96b97dc7 |  |
|  | 28a51b622b |  |
|  | 8bd1da7144 |  |
|  | e4d0b8ee6e |  |
|  | 1dfb28b362 |  |
.gitignore (vendored, new file, 160 lines)

@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
README.md (255 lines changed)

@@ -12,17 +12,25 @@

## Changelog

-[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
+[23/08/18] Now we support **resuming training**, upgrade `transformers` to `4.31.0` to enjoy this feature.
+
+[23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
+
+[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
+
+[23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model.
+
+[23/07/31] Now we support **dataset streaming**. Try `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode.

[23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.

[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.

-[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
+[23/07/18] Now we develop an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.

[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model.

-[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
+[23/07/09] Now we release **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.

[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model.
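The changelog entries above introduce several new flags (`--streaming`, `--max_steps`, `--rope_scaling`, and a DPO stage). As a rough, hedged sketch of how the streaming and RoPE-scaling flags might combine with the supervised fine-tuning command shown later in this README — the paths and the step count are placeholders, not values taken from this PR:

```bash
# Illustrative only: stream the dataset and apply linear RoPE scaling during SFT
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --do_train \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --rope_scaling linear \
    --streaming \
    --max_steps 10000 \
    --output_dir path_to_sft_checkpoint \
    --fp16
```

Note that `--max_steps` stands in for epoch-based scheduling here, since a streamed dataset has no known length.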
@@ -40,28 +48,33 @@

## Supported Models

-- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
-- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
-- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
-- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
-- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
-- [InternLM](https://github.com/InternLM/InternLM) (7B)
+| Model | Model size | Default module | Template |
+| -------------------------------------------------------- | --------------------------- | ----------------- |----------|
+| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
+| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
+| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
+| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | W_pack | baichuan |
+| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
+| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
+| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
+| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
+
+- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
+- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.

## Supported Training Approaches

-- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
-- Full-parameter tuning
-- Partial-parameter tuning
-- [LoRA](https://arxiv.org/abs/2106.09685)
-- [QLoRA](https://arxiv.org/abs/2305.14314)
-- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
-- Full-parameter tuning
-- Partial-parameter tuning
-- [LoRA](https://arxiv.org/abs/2106.09685)
-- [QLoRA](https://arxiv.org/abs/2305.14314)
-- [RLHF](https://arxiv.org/abs/2203.02155)
-- [LoRA](https://arxiv.org/abs/2106.09685)
-- [QLoRA](https://arxiv.org/abs/2305.14314)
+| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
+| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
+| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Reward Modeling | | | :white_check_mark: | :white_check_mark: |
+| PPO Training | | | :white_check_mark: | :white_check_mark: |
+| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
+
+- Use `--quantization_bit 4/8` argument to enable QLoRA.

## Provided Datasets
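Per the tables above, each model family has a default LoRA module (for example `W_pack` for Baichuan), and QLoRA is enabled via `--quantization_bit 4/8`. A hedged sketch of a 4-bit QLoRA run for a Baichuan base model under those assumptions — the paths are placeholders and the combination is illustrative rather than taken from this PR:

```bash
# Illustrative only: 4-bit QLoRA supervised fine-tuning of Baichuan-13B-Base
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path baichuan-inc/Baichuan-13B-Base \
    --do_train \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target W_pack \
    --quantization_bit 4 \
    --output_dir path_to_qlora_checkpoint \
    --fp16
```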
@@ -78,7 +91,6 @@
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
-- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)

@@ -93,7 +105,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- For reward modelling:
+- For reward modeling or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -111,6 +123,7 @@ huggingface-cli login

- Python 3.8+ and PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
+- sentencepiece and tiktoken
- jieba, rouge-chinese and nltk (used at evaluation)
- gradio and matplotlib (used in web_demo.py)
- uvicorn, fastapi and sse-starlette (used in api_demo.py)

@@ -128,7 +141,6 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t

### Dependence Installation (optional)

```bash
-git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning
@@ -148,18 +160,23 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```

+We strongly recommend using the all-in-one Web UI for newcomers since it can also generate training scripts **automatically**.
+
Currently the web UI only supports training on **a single GPU**.

-### (Continually) Pre-Training
+### Train on a single GPU
+
+#### Pre-Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage pt \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset wiki_demo \
    --template default \
    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
    --output_dir path_to_pt_checkpoint \
    --overwrite_cache \
    --per_device_train_batch_size 4 \
@@ -173,16 +190,17 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --fp16
```

-### Supervised Fine-Tuning
+#### Supervised Fine-Tuning

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
    --output_dir path_to_sft_checkpoint \
    --overwrite_cache \
    --per_device_train_batch_size 4 \
@@ -196,42 +214,42 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --fp16
```

-Remember to specify `--lora_target W_pack` if you are using Baichuan models.
-
-### Reward Model Training
+#### Reward Modeling

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage rm \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset comparison_gpt4_en \
    --template default \
    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
    --resume_lora_training False \
    --checkpoint_dir path_to_sft_checkpoint \
    --output_dir path_to_rm_checkpoint \
-    --per_device_train_batch_size 4 \
+    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
-    --learning_rate 1e-5 \
+    --learning_rate 1e-6 \
    --num_train_epochs 1.0 \
    --plot_loss \
    --fp16
```

-### PPO Training (RLHF)
+#### PPO Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage ppo \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
    --resume_lora_training False \
    --checkpoint_dir path_to_sft_checkpoint \
    --reward_model path_to_rm_checkpoint \
@@ -243,17 +261,45 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --save_steps 1000 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
-    --plot_loss
+    --plot_loss \
+    --fp16
```

+#### DPO Training
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage dpo \
+    --model_name_or_path path_to_llama_model \
+    --do_train \
+    --dataset comparison_gpt4_en \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --resume_lora_training False \
+    --checkpoint_dir path_to_sft_checkpoint \
+    --output_dir path_to_dpo_checkpoint \
+    --per_device_train_batch_size 2 \
+    --gradient_accumulation_steps 4 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 1000 \
+    --learning_rate 1e-5 \
+    --num_train_epochs 1.0 \
+    --plot_loss \
+    --fp16
+```
+
### Distributed Training

+#### Use Huggingface Accelerate
+
```bash
accelerate config # configure the environment
accelerate launch src/train_bash.py # arguments (same as above)
```

-<details><summary>Example configuration for full-tuning with DeepSpeed ZeRO-2</summary>
+<details><summary>Example config.yaml for training with DeepSpeed ZeRO-2</summary>

```yaml
compute_environment: LOCAL_MACHINE
@@ -281,12 +327,93 @@ use_cpu: false

</details>

+#### Use DeepSpeed
+
+```bash
+deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
+    --deepspeed ds_config.json \
+    ... # arguments (same as above)
+```
+
+<details><summary>Example ds_config.json for training with DeepSpeed ZeRO-2</summary>
+
+```json
+{
+  "train_micro_batch_size_per_gpu": "auto",
+  "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
+  "zero_allow_untested_optimizer": true,
+  "fp16": {
+    "enabled": "auto",
+    "loss_scale": 0,
+    "initial_scale_power": 16,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "zero_optimization": {
+    "stage": 2,
+    "allgather_partitions": true,
+    "allgather_bucket_size": 5e8,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 5e8,
+    "overlap_comm": false,
+    "contiguous_gradients": true
+  }
+}
+```
+
+</details>
+
+### Export model
+
+```bash
+python src/export_model.py \
+    --model_name_or_path path_to_llama_model \
+    --template default \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint \
+    --output_dir path_to_export
+```
+
+### API Demo
+
+```bash
+python src/api_demo.py \
+    --model_name_or_path path_to_llama_model \
+    --template default \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint
+```
+
+Visit `http://localhost:8000/docs` for API documentation.
+
+### CLI Demo
+
+```bash
+python src/cli_demo.py \
+    --model_name_or_path path_to_llama_model \
+    --template default \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint
+```
+
+### Web Demo
+
+```bash
+python src/web_demo.py \
+    --model_name_or_path path_to_llama_model \
+    --template default \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint
+```
+
### Evaluation (BLEU and ROUGE_CHINESE)

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --do_eval \
    --dataset alpaca_gpt4_en \
    --template default \
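The header comment removed from `src/api_demo.py` later in this diff states that the API follows OpenAI's chat format, so a quick smoke test of the running server might look like the sketch below. The `/v1/chat/completions` route and the payload fields are assumptions borrowed from the OpenAI API rather than something shown in this diff; check `http://localhost:8000/docs` for the actual schema:

```bash
# Hypothetical request against the locally served fine-tuned model
curl -s http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"messages": [{"role": "user", "content": "Hello!"}]}'
```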
@@ -305,7 +432,7 @@ We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --do_predict \
    --dataset alpaca_gpt4_en \
    --template default \
@@ -317,49 +444,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --predict_with_generate
```

-### API Demo
-
-```bash
-python src/api_demo.py \
-    --model_name_or_path path_to_your_model \
-    --template default \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint
-```
-
-Visit `http://localhost:8000/docs` for API documentation.
-
-### CLI Demo
-
-```bash
-python src/cli_demo.py \
-    --model_name_or_path path_to_your_model \
-    --template default \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint
-```
-
-### Web Demo
-
-```bash
-python src/web_demo.py \
-    --model_name_or_path path_to_your_model \
-    --template default \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint
-```
-
-### Export model
-
-```bash
-python src/export_model.py \
-    --model_name_or_path path_to_your_model \
-    --template default \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export
-```
-
## TODO

- [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
@@ -378,6 +462,9 @@ Please follow the model licenses to use the corresponding model weights:
- [Falcon](LICENSE)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
+- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
+- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
+- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)

## Citation
README_zh.md (240 lines changed)

@@ -12,73 +12,85 @@

## Changelog

-[23/07/31] Now we support streaming loading of the training data. Try the `--streaming` and `--max_steps 100` arguments to load your dataset in streaming mode.
+[23/08/18] Now we support **resuming training**. Please upgrade `transformers` to `4.31.0` to enable this feature.
+
+[23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try the `--rope_scaling linear` argument for training or the `--rope_scaling dynamic` argument for evaluation.
+
+[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-训练) for details (experimental feature).
+
+[23/08/03] Now we support training the **Qwen-7B** model. Try the `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments. Please add the `--template chatml` argument when using the Qwen-7B-Chat model.
+
+[23/07/31] Now we support **dataset streaming**. Try the `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode.

[23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.

-[23/07/19] Now we support training the **LLaMA-2** models. Try the `--model_name_or_path meta-llama/Llama-2-7b-hf` argument. Note that using the LLaMA-2-chat model requires adding the `--template llama2` argument.
+[23/07/19] Now we support training the **LLaMA-2** models. Try the `--model_name_or_path meta-llama/Llama-2-7b-hf` argument. Please add the `--template llama2` argument when using the LLaMA-2-chat model.

-[23/07/18] We developed a browser-based one-click fine-tuning interface for training and testing. Try `train_web.py` to fine-tune models in your browser. Thanks to [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
+[23/07/18] We developed an **all-in-one browser UI** for training and testing. Try `train_web.py` to fine-tune models in your browser. Thanks to [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.

-[23/07/11] Now we support training the **Baichuan-13B** model. Try the `--model_name_or_path path_to_baichuan_model` and `--lora_target W_pack` arguments. Note that using the Baichuan-13B-Chat model requires adding the `--template baichuan` argument.
+[23/07/11] Now we support training the **Baichuan-13B** model. Try the `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments. Please add the `--template baichuan` argument when using the Baichuan-13B-Chat model.

-[23/07/09] We open-sourced [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use toolkit for quickly editing the factual memory of large language models. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
+[23/07/09] We open-sourced **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use toolkit for quickly editing the factual memory of large language models. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.

-[23/07/07] Now we support training the **InternLM-7B** model. Try the `--model_name_or_path internlm/internlm-7b` argument. Note that using the InternLM-chat model requires adding the `--template intern` argument.
+[23/07/07] Now we support training the **InternLM-7B** model. Try the `--model_name_or_path internlm/internlm-7b` argument. Please add the `--template intern` argument when using the InternLM-chat model.

[23/07/05] Now we support training the **Falcon-7B/40B** models. Try the `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments.

[23/06/29] We provided a **reproducible** example of instruction-tuning a model; see the [Hugging Face project](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.

-[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI API](https://platform.openai.com/docs/api-reference/chat) format, so you can plug the fine-tuned model into any ChatGPT-based application.
+[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI API](https://platform.openai.com/docs/api-reference/chat) format, so you can plug the fine-tuned model into **any ChatGPT-based application**.

[23/06/15] Now we support training the **Baichuan-7B** model. Try the `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments.

-[23/06/03] Now we implement 4-bit LoRA training (also known as [QLoRA](https://github.com/artidoro/qlora)). Try the `--quantization_bit 4` argument for 4-bit quantized fine-tuning.
+[23/06/03] Now we implement 4-bit LoRA training (also known as **[QLoRA](https://github.com/artidoro/qlora)**). Try the `--quantization_bit 4` argument for 4-bit quantized fine-tuning.

[23/05/31] Now we support training the **BLOOM & BLOOMZ** models. Try the `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments.

## Models

-- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
-- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
-- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
-- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
-- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
-- [InternLM](https://github.com/InternLM/InternLM) (7B)
+| Model | Model size | Default module | Template |
+| -------------------------------------------------------- | --------------------------- | ----------------- |----------|
+| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
+| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
+| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
+| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
+| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | W_pack | baichuan |
+| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
+| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
+| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | q_proj,v_proj | - |
+| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | query_key_value | chatglm2 |
+
+- The **default module** is one of the possible options for the `--lora_target` argument. Use `python src/train_bash.py -h` to see all available options.
+- For all "base" models, the `--template` argument can be any of `default`, `alpaca`, `vicuna`, etc. Make sure to use the corresponding template for the "chat" models.

-## Fine-Tuning Methods
-
-- [(Continual) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
-- Full-parameter tuning
-- Partial-parameter tuning
-- [LoRA](https://arxiv.org/abs/2106.09685)
-- [QLoRA](https://arxiv.org/abs/2305.14314)
-- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
-- Full-parameter tuning
-- Partial-parameter tuning
-- [LoRA](https://arxiv.org/abs/2106.09685)
-- [QLoRA](https://arxiv.org/abs/2305.14314)
-- [Reinforcement learning from human feedback (RLHF)](https://arxiv.org/abs/2203.02155)
-- [LoRA](https://arxiv.org/abs/2106.09685)
-- [QLoRA](https://arxiv.org/abs/2305.14314)
+## Training Methods
+
+| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
+| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
+| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Reward Modeling | | | :white_check_mark: | :white_check_mark: |
+| PPO Training | | | :white_check_mark: | :white_check_mark: |
+| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
+
+- Use the `--quantization_bit 4/8` argument to enable QLoRA training.

## Datasets

-- For continual pre-training:
+- For pre-training:
- [Wiki Demo (en)](data/wiki_demo.txt)
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
- For supervised fine-tuning:
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [Self-cognition (zh)](data/self_cognition.json)
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
-- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -93,7 +105,7 @@
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- For reward model training:
+- For reward model or DPO training:
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)

@@ -111,6 +123,7 @@ huggingface-cli login

- Python 3.8+ and PyTorch 1.13.1+
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
+- sentencepiece and tiktoken
- jieba, rouge-chinese and nltk (used for evaluation)
- gradio and matplotlib (used for the web UI)
- uvicorn, fastapi and sse-starlette (used for the API)

@@ -128,7 +141,6 @@ huggingface-cli login

### Environment Setup (optional)

```bash
-git lfs install
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
conda create -n llama_etuning python=3.10
conda activate llama_etuning

@@ -142,24 +154,29 @@ pip install -r requirements.txt
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```

-### One-Click Browser Fine-Tuning/Testing
+### All-in-One Browser UI

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_web.py
```

+We strongly recommend newcomers to use the all-in-one browser UI, since it can also generate the required command-line scripts **automatically**.
+
Currently the web UI only supports training on **a single GPU**.

-### Continual Pre-Training
+### Training on a Single GPU
+
+#### Pre-Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage pt \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset wiki_demo \
    --template default \
    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
    --output_dir path_to_pt_checkpoint \
    --overwrite_cache \
    --per_device_train_batch_size 4 \
@@ -173,16 +190,17 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --fp16
```

-### Supervised Fine-Tuning
+#### Supervised Fine-Tuning

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --template default \
    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
    --output_dir path_to_sft_checkpoint \
    --overwrite_cache \
    --per_device_train_batch_size 4 \

@@ -196,42 +214,42 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --fp16
```

-Please specify `--lora_target W_pack` when using Baichuan models.
-
-### Reward Model Training
+#### Reward Model Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage rm \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset comparison_gpt4_zh \
    --template default \
    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
    --resume_lora_training False \
    --checkpoint_dir path_to_sft_checkpoint \
    --output_dir path_to_rm_checkpoint \
-    --per_device_train_batch_size 4 \
+    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
-    --learning_rate 1e-5 \
+    --learning_rate 1e-6 \
    --num_train_epochs 1.0 \
    --plot_loss \
    --fp16
```

-### RLHF Training
+#### PPO Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage ppo \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --template default \
    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
    --resume_lora_training False \
    --checkpoint_dir path_to_sft_checkpoint \
    --reward_model path_to_rm_checkpoint \
@@ -246,8 +264,35 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --plot_loss
```

+#### DPO Training
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage dpo \
+    --model_name_or_path path_to_llama_model \
+    --do_train \
+    --dataset comparison_gpt4_zh \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --resume_lora_training False \
+    --checkpoint_dir path_to_sft_checkpoint \
+    --output_dir path_to_dpo_checkpoint \
+    --per_device_train_batch_size 2 \
+    --gradient_accumulation_steps 4 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 1000 \
+    --learning_rate 1e-5 \
+    --num_train_epochs 1.0 \
+    --plot_loss \
+    --fp16
+```
+
### Distributed Training on Multiple GPUs

+#### Use Huggingface Accelerate
+
```bash
accelerate config # configure the distributed environment first
accelerate launch src/train_bash.py # same arguments as above

@@ -281,47 +326,60 @@ use_cpu: false

</details>

-### Evaluation (BLEU and Chinese ROUGE scores)
+#### Use DeepSpeed

```bash
-CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
-    --stage sft \
-    --model_name_or_path path_to_your_model \
-    --do_eval \
-    --dataset alpaca_gpt4_zh \
-    --template default \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_eval_result \
-    --per_device_eval_batch_size 8 \
-    --max_samples 100 \
-    --predict_with_generate
+deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
+    --deepspeed ds_config.json \
+    ... # same arguments as above
```

-We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` when evaluating quantized models.
-
-### Model Prediction
+<details><summary>Example DeepSpeed configuration for full-parameter fine-tuning with DeepSpeed ZeRO-2</summary>
+
+```json
+{
+  "train_micro_batch_size_per_gpu": "auto",
+  "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
+  "zero_allow_untested_optimizer": true,
+  "fp16": {
+    "enabled": "auto",
+    "loss_scale": 0,
+    "initial_scale_power": 16,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "zero_optimization": {
+    "stage": 2,
+    "allgather_partitions": true,
+    "allgather_bucket_size": 5e8,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 5e8,
+    "overlap_comm": false,
+    "contiguous_gradients": true
+  }
+}
+```
+
+</details>
+
+### Export the Fine-Tuned Model

```bash
-CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
-    --stage sft \
-    --model_name_or_path path_to_your_model \
-    --do_predict \
-    --dataset alpaca_gpt4_zh \
+python src/export_model.py \
+    --model_name_or_path path_to_llama_model \
    --template default \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_predict_result \
-    --per_device_eval_batch_size 8 \
-    --max_samples 100 \
-    --predict_with_generate
+    --output_dir path_to_export
```

### API Service

```bash
python src/api_demo.py \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --template default \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint
@@ -333,7 +391,7 @@ python src/api_demo.py \

```bash
python src/cli_demo.py \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --template default \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint

@@ -343,21 +401,46 @@ python src/cli_demo.py \

```bash
python src/web_demo.py \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --template default \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint
```

-### Export the Fine-Tuned Model
+### Evaluation (BLEU and Chinese ROUGE scores)

```bash
-python src/export_model.py \
-    --model_name_or_path path_to_your_model \
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage sft \
+    --model_name_or_path path_to_llama_model \
+    --do_eval \
+    --dataset alpaca_gpt4_zh \
    --template default \
    --finetuning_type lora \
    --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export
+    --output_dir path_to_eval_result \
+    --per_device_eval_batch_size 8 \
+    --max_samples 100 \
+    --predict_with_generate
+```
+
+We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` when evaluating quantized models.
+
+### Model Prediction
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage sft \
+    --model_name_or_path path_to_llama_model \
+    --do_predict \
+    --dataset alpaca_gpt4_zh \
+    --template default \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint \
+    --output_dir path_to_predict_result \
+    --per_device_eval_batch_size 8 \
+    --max_samples 100 \
+    --predict_with_generate
```

## TODO

@@ -378,6 +461,9 @@ python src/export_model.py \
- [Falcon](LICENSE)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
+- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
+- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
+- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)

## Citation
Two single-line files deleted:

@@ -1 +0,0 @@
-f967a4f6d04a11308a15524aa9a846a19a8d1e83

@@ -1 +0,0 @@
-0a4f0d74fd1c5cab2eb6d84a3a3fe669847becd8
requirements.txt

@@ -3,8 +3,10 @@ transformers>=4.29.1
datasets>=2.12.0
accelerate>=0.21.0
peft>=0.4.0
-trl>=0.4.7
+trl>=0.5.0
+scipy
sentencepiece
+tiktoken
jieba
rouge-chinese
nltk
@@ -1,19 +1,13 @@
-# coding=utf-8
-# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
-# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
-# Visit http://localhost:8000/docs for document.
-
 import uvicorn
 
-from llmtuner import ChatModel
-from llmtuner.api.app import create_app
-from llmtuner.tuner import get_infer_args
+from llmtuner import ChatModel, create_app
 
 
 def main():
-    chat_model = ChatModel(*get_infer_args())
+    chat_model = ChatModel()
     app = create_app(chat_model)
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    print("Visit http://localhost:8000/docs for API document.")
 
 
 if __name__ == "__main__":
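
Since `api_demo.py` now serves an OpenAI-style `/v1/chat/completions` route (see the `create_app` changes further below), a fine-tuned model can be queried with any OpenAI-compatible client. A hedged example with `requests`, assuming the server runs on the default `0.0.0.0:8000` and that the response body follows the usual OpenAI schema; the `model` field is a free-form label:

```python
# Query the OpenAI-style endpoint exposed by api_demo.py. Host and port match
# the uvicorn.run() call above; the model name is an arbitrary label and the
# response layout is assumed to follow the OpenAI chat-completions schema.
import requests

payload = {
    "model": "llama-sft",  # arbitrary label, not used for routing
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "你好,请介绍一下你自己。"},
    ],
    "temperature": 0.7,
    "top_p": 0.9,
    "max_tokens": 256,
    "stream": False,
}
response = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=300)
print(response.json()["choices"][0]["message"]["content"])
```
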
@@ -1,13 +1,8 @@
-# coding=utf-8
-# Implements stream chat in command line for fine-tuned models.
-# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
-
 from llmtuner import ChatModel
-from llmtuner.tuner import get_infer_args
 
 
 def main():
-    chat_model = ChatModel(*get_infer_args())
+    chat_model = ChatModel()
     history = []
     print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
@@ -1,16 +1,8 @@
-# coding=utf-8
-# Exports the fine-tuned model.
-# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
-
-from llmtuner.tuner import get_train_args, load_model_and_tokenizer
+from llmtuner import export_model
 
 
 def main():
-    model_args, _, training_args, finetuning_args, _ = get_train_args()
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-    model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
-    tokenizer.save_pretrained(training_args.output_dir)
-    print("model and tokenizer have been saved at:", training_args.output_dir)
+    export_model()
 
 
 if __name__ == "__main__":
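
The rewritten script delegates everything to `llmtuner.export_model`, which still reads the usual flags from the command line, as in the removed usage comment. A programmatic call is also possible if `export_model` accepts the same optional argument dict as `ChatModel` does below; that signature is an assumption, not shown in this diff, and all paths are placeholders:

```python
# Hedged sketch: calling export_model() with an explicit argument dict instead
# of CLI flags. The dict-style signature is an assumption based on ChatModel();
# every path below is a placeholder.
from llmtuner import export_model

export_model(dict(
    model_name_or_path="path_to_llama_model",
    template="default",
    finetuning_type="lora",
    checkpoint_dir="path_to_checkpoint",
    output_dir="path_to_export",
))
```
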
@@ -1,4 +1,9 @@
+# Level: api, webui > chat > tuner > dsets > extras, hparams
+
+from llmtuner.api import create_app
 from llmtuner.chat import ChatModel
+from llmtuner.tuner import export_model, run_exp
+from llmtuner.webui import create_ui, create_web_demo
 
 
-__version__ = "0.1.5"
+__version__ = "0.1.7"
@@ -0,0 +1 @@
+from llmtuner.api.app import create_app
@@ -5,9 +5,8 @@ from contextlib import asynccontextmanager
 from sse_starlette import EventSourceResponse
 from typing import List, Tuple
 
-from llmtuner.tuner import get_infer_args
 from llmtuner.extras.misc import torch_gc
-from llmtuner.chat.stream_chat import ChatModel
+from llmtuner.chat import ChatModel
 from llmtuner.api.protocol import (
     Role,
     Finish,
@@ -48,15 +47,15 @@ def create_app(chat_model: ChatModel) -> FastAPI:
 
     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
     async def create_chat_completion(request: ChatCompletionRequest):
-        if request.messages[-1].role != Role.USER:
+        if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=400, detail="Invalid request")
-        query = request.messages[-1].content
 
+        query = request.messages[-1].content
         prev_messages = request.messages[:-1]
         if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
-            prefix = prev_messages.pop(0).content
+            system = prev_messages.pop(0).content
         else:
-            prefix = None
+            system = None
 
         history = []
         if len(prev_messages) % 2 == 0:
@@ -65,11 +64,11 @@ def create_app(chat_model: ChatModel) -> FastAPI:
                 history.append([prev_messages[i].content, prev_messages[i+1].content])
 
         if request.stream:
-            generate = predict(query, history, prefix, request)
+            generate = predict(query, history, system, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
         response, (prompt_length, response_length) = chat_model.chat(
-            query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+            query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
         )
 
         usage = ChatCompletionResponseUsage(
@@ -86,7 +85,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
 
         return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
 
-    async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
+    async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(role=Role.ASSISTANT),
@@ -96,7 +95,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
         yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
         for new_text in chat_model.stream_chat(
-            query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+            query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
         ):
             if len(new_text) == 0:
                 continue
@@ -122,6 +121,6 @@
 
 
 if __name__ == "__main__":
-    chat_model = ChatModel(*get_infer_args())
+    chat_model = ChatModel()
     app = create_app(chat_model)
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
@@ -1,44 +1,37 @@
|
|||||||
import torch
|
import torch
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from transformers import TextIteratorStreamer
|
from transformers import TextIteratorStreamer
|
||||||
|
|
||||||
from llmtuner.extras.misc import dispatch_model, get_logits_processor
|
from llmtuner.extras.misc import dispatch_model, get_logits_processor
|
||||||
from llmtuner.extras.template import get_template
|
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||||
from llmtuner.tuner import load_model_and_tokenizer
|
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
|
||||||
|
|
||||||
|
|
||||||
class ChatModel:
|
class ChatModel:
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||||
self,
|
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
||||||
model_args: "ModelArguments",
|
|
||||||
data_args: "DataArguments",
|
|
||||||
finetuning_args: "FinetuningArguments",
|
|
||||||
generating_args: "GeneratingArguments"
|
|
||||||
) -> None:
|
|
||||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||||
self.model = dispatch_model(self.model)
|
self.model = dispatch_model(self.model)
|
||||||
self.template = get_template(data_args.template)
|
self.model = self.model.eval() # enable evaluation mode
|
||||||
self.source_prefix = data_args.source_prefix
|
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||||
self.generating_args = generating_args
|
self.system_prompt = data_args.system_prompt
|
||||||
|
|
||||||
def process_args(
|
def process_args(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
history: Optional[List[Tuple[str, str]]] = None,
|
history: Optional[List[Tuple[str, str]]] = None,
|
||||||
prefix: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
**input_kwargs
|
**input_kwargs
|
||||||
) -> Tuple[Dict[str, Any], int]:
|
) -> Tuple[Dict[str, Any], int]:
|
||||||
prefix = prefix or self.source_prefix
|
system = system or self.system_prompt
|
||||||
|
|
||||||
prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token)
|
prompt, _ = self.template.encode_oneturn(
|
||||||
inputs = self.tokenizer([prompt], return_tensors="pt")
|
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
|
||||||
inputs = inputs.to(self.model.device)
|
)
|
||||||
prompt_length = len(inputs["input_ids"][0])
|
input_ids = torch.tensor([prompt], device=self.model.device)
|
||||||
|
prompt_length = len(input_ids[0])
|
||||||
|
|
||||||
do_sample = input_kwargs.pop("do_sample", None)
|
do_sample = input_kwargs.pop("do_sample", None)
|
||||||
temperature = input_kwargs.pop("temperature", None)
|
temperature = input_kwargs.pop("temperature", None)
|
||||||
@@ -50,12 +43,14 @@ class ChatModel:
|
|||||||
|
|
||||||
gen_kwargs = self.generating_args.to_dict()
|
gen_kwargs = self.generating_args.to_dict()
|
||||||
gen_kwargs.update(dict(
|
gen_kwargs.update(dict(
|
||||||
input_ids=inputs["input_ids"],
|
input_ids=input_ids,
|
||||||
do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
|
do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
|
||||||
temperature=temperature or gen_kwargs["temperature"],
|
temperature=temperature or gen_kwargs["temperature"],
|
||||||
top_p=top_p or gen_kwargs["top_p"],
|
top_p=top_p or gen_kwargs["top_p"],
|
||||||
top_k=top_k or gen_kwargs["top_k"],
|
top_k=top_k or gen_kwargs["top_k"],
|
||||||
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
|
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
|
||||||
|
eos_token_id=list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)),
|
||||||
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
logits_processor=get_logits_processor()
|
logits_processor=get_logits_processor()
|
||||||
))
|
))
|
||||||
|
|
||||||
@@ -74,10 +69,10 @@ class ChatModel:
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
history: Optional[List[Tuple[str, str]]] = None,
|
history: Optional[List[Tuple[str, str]]] = None,
|
||||||
prefix: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
**input_kwargs
|
**input_kwargs
|
||||||
) -> Tuple[str, Tuple[int, int]]:
|
) -> Tuple[str, Tuple[int, int]]:
|
||||||
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
|
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
|
||||||
generation_output = self.model.generate(**gen_kwargs)
|
generation_output = self.model.generate(**gen_kwargs)
|
||||||
outputs = generation_output.tolist()[0][prompt_length:]
|
outputs = generation_output.tolist()[0][prompt_length:]
|
||||||
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
|
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
|
||||||
@@ -89,10 +84,10 @@ class ChatModel:
|
|||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
history: Optional[List[Tuple[str, str]]] = None,
|
history: Optional[List[Tuple[str, str]]] = None,
|
||||||
prefix: Optional[str] = None,
|
system: Optional[str] = None,
|
||||||
**input_kwargs
|
**input_kwargs
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
|
gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
|
||||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||||
gen_kwargs["streamer"] = streamer
|
gen_kwargs["streamer"] = streamer
|
||||||
|
|
||||||
|
|||||||
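
With the constructor change above, `ChatModel` now builds its model, data, and generation arguments from an optional dict via `get_infer_args`, so inference can be driven programmatically without wiring up argument dataclasses by hand. A minimal sketch; all paths are placeholders and the generation kwargs mirror the ones handled in `process_args`:

```python
# Minimal sketch of programmatic inference with the new ChatModel signature.
# Paths are placeholders; chat() returns (response, (prompt_length, response_length)).
from llmtuner import ChatModel

chat_model = ChatModel(dict(
    model_name_or_path="path_to_llama_model",
    template="default",
    finetuning_type="lora",
    checkpoint_dir="path_to_checkpoint",
))

response, (prompt_length, response_length) = chat_model.chat(
    "你好,请介绍一下你自己。", history=[], system=None, temperature=0.7
)
print(response)

# Streaming variant: stream_chat() yields text chunks as they are generated.
for new_text in chat_model.stream_chat("再讲一个笑话。", history=[("你好,请介绍一下你自己。", response)]):
    print(new_text, end="", flush=True)
```
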
@@ -1,48 +1,25 @@
|
|||||||
import os
|
import os
|
||||||
import hashlib
|
from typing import TYPE_CHECKING, List, Union
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
|
||||||
|
|
||||||
from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
|
from datasets import concatenate_datasets, interleave_datasets, load_dataset
|
||||||
|
|
||||||
|
from llmtuner.dsets.utils import checksum, EXT2TYPE
|
||||||
from llmtuner.extras.logging import get_logger
|
from llmtuner.extras.logging import get_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset
|
from datasets import Dataset, IterableDataset
|
||||||
from llmtuner.hparams import ModelArguments, DataArguments
|
from llmtuner.hparams import ModelArguments, DataArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
EXT2TYPE = {
|
|
||||||
"csv": "csv",
|
|
||||||
"json": "json",
|
|
||||||
"jsonl": "json",
|
|
||||||
"txt": "text"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
|
||||||
if file_sha1 is None:
|
|
||||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if len(data_files) != 1:
|
|
||||||
logger.warning("Checksum failed: too many files.")
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(data_files[0], "rb") as f:
|
|
||||||
sha1 = hashlib.sha1(f.read()).hexdigest()
|
|
||||||
if sha1 != file_sha1:
|
|
||||||
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(
|
def get_dataset(
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
data_args: "DataArguments"
|
data_args: "DataArguments"
|
||||||
) -> "Dataset":
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
max_samples = data_args.max_samples
|
max_samples = data_args.max_samples
|
||||||
all_datasets: List["Dataset"] = [] # support multiple datasets
|
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
|
||||||
|
|
||||||
for dataset_attr in data_args.dataset_list:
|
for dataset_attr in data_args.dataset_list:
|
||||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||||
@@ -92,12 +69,11 @@ def get_dataset(
|
|||||||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
||||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||||
|
|
||||||
if dataset_attr.source_prefix: # add prefix
|
if dataset_attr.system_prompt: # add system prompt
|
||||||
features = None
|
|
||||||
if data_args.streaming:
|
if data_args.streaming:
|
||||||
features = dataset.features
|
dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
|
||||||
features["prefix"] = Value(dtype="string", id=None)
|
else:
|
||||||
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
|
dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
|
||||||
|
|
||||||
all_datasets.append(dataset)
|
all_datasets.append(dataset)
|
||||||
|
|
||||||
@@ -111,6 +87,6 @@ def get_dataset(
|
|||||||
if not data_args.streaming:
|
if not data_args.streaming:
|
||||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||||
stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
||||||
return interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)
|
return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown mixing strategy.")
|
raise ValueError("Unknown mixing strategy.")
|
||||||
|
|||||||
@@ -1,37 +1,43 @@
|
|||||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
|
import tiktoken
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
|
||||||
from llmtuner.extras.constants import IGNORE_INDEX
|
from llmtuner.extras.constants import IGNORE_INDEX
|
||||||
from llmtuner.extras.template import get_template
|
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset
|
from datasets import Dataset, IterableDataset
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
from llmtuner.hparams import DataArguments
|
from llmtuner.hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
def preprocess_dataset(
|
def preprocess_dataset(
|
||||||
dataset: "Dataset",
|
dataset: Union["Dataset", "IterableDataset"],
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||||
) -> "Dataset":
|
) -> Union["Dataset", "IterableDataset"]:
|
||||||
column_names = list(dataset.column_names)
|
column_names = list(next(iter(dataset)).keys())
|
||||||
template = get_template(data_args.template)
|
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
||||||
|
|
||||||
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
||||||
for i in range(len(examples["prompt"])):
|
for i in range(len(examples["prompt"])):
|
||||||
query, response = examples["prompt"][i], examples["response"][i]
|
query, response = examples["prompt"][i], examples["response"][i]
|
||||||
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
|
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
|
||||||
history = examples["history"][i] if "history" in examples else None
|
history = examples["history"][i] if "history" in examples else None
|
||||||
prefix = examples["prefix"][i] if "prefix" in examples else None
|
system = examples["system"][i] if "system" in examples else None
|
||||||
yield query, response, history, prefix
|
yield query, response, history, system
|
||||||
|
|
||||||
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||||
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
|
# build grouped texts with format `X1 X2 X3 ...` (without <eos>)
|
||||||
tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
|
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
||||||
|
kwargs = dict(allowed_special="all")
|
||||||
|
else:
|
||||||
|
kwargs = dict(add_special_tokens=False)
|
||||||
|
|
||||||
|
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
|
||||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||||
block_size = data_args.max_source_length
|
block_size = data_args.max_source_length
|
||||||
@@ -42,33 +48,28 @@ def preprocess_dataset(
|
|||||||
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
||||||
for k, t in concatenated_examples.items()
|
for k, t in concatenated_examples.items()
|
||||||
}
|
}
|
||||||
result["labels"] = result["input_ids"].copy()
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||||
# for input with history, we build multiple input-label pairs just like:
|
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||||
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
|
|
||||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||||
max_length = data_args.max_source_length + data_args.max_target_length
|
max_length = data_args.max_source_length + data_args.max_target_length
|
||||||
|
|
||||||
for query, response, history, prefix in construct_example(examples):
|
for query, response, history, system in construct_example(examples):
|
||||||
input_ids, labels = [], []
|
input_ids, labels = [], []
|
||||||
|
|
||||||
for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)):
|
for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
|
||||||
source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0))
|
|
||||||
target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False)
|
|
||||||
|
|
||||||
if len(source_ids) > data_args.max_source_length:
|
if len(source_ids) > data_args.max_source_length:
|
||||||
source_ids = source_ids[:data_args.max_source_length]
|
source_ids = source_ids[:data_args.max_source_length]
|
||||||
if len(target_ids) > data_args.max_target_length - 1: # eos token
|
if len(target_ids) > data_args.max_target_length:
|
||||||
target_ids = target_ids[:data_args.max_target_length - 1]
|
target_ids = target_ids[:data_args.max_target_length]
|
||||||
|
|
||||||
if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
|
if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
|
||||||
break
|
break
|
||||||
|
|
||||||
input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
|
input_ids += source_ids + target_ids
|
||||||
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
|
labels += [IGNORE_INDEX] * len(source_ids) + target_ids
|
||||||
|
|
||||||
model_inputs["input_ids"].append(input_ids)
|
model_inputs["input_ids"].append(input_ids)
|
||||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||||
@@ -77,14 +78,11 @@ def preprocess_dataset(
|
|||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||||
# build inputs with format `<bos> X` and labels with format `<bos> Y`
|
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||||
|
|
||||||
for query, response, history, prefix in construct_example(examples):
|
for query, response, history, system in construct_example(examples):
|
||||||
prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
|
source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, system)
|
||||||
|
|
||||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
|
||||||
target_ids = tokenizer.encode(text=response, add_special_tokens=True)
|
|
||||||
|
|
||||||
if len(source_ids) > data_args.max_source_length:
|
if len(source_ids) > data_args.max_source_length:
|
||||||
source_ids = source_ids[:data_args.max_source_length]
|
source_ids = source_ids[:data_args.max_source_length]
|
||||||
@@ -98,43 +96,39 @@ def preprocess_dataset(
|
|||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
def preprocess_pairwise_dataset(examples):
|
def preprocess_pairwise_dataset(examples):
|
||||||
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
|
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||||
model_inputs = {"accept_ids": [], "reject_ids": []}
|
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
||||||
for query, response, history, prefix in construct_example(examples):
|
for query, response, history, system in construct_example(examples):
|
||||||
prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
|
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
|
||||||
|
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
|
||||||
|
|
||||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
if len(prompt_ids) > data_args.max_source_length:
|
||||||
accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False)
|
prompt_ids = prompt_ids[:data_args.max_source_length]
|
||||||
reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False)
|
if len(chosen_ids) > data_args.max_target_length:
|
||||||
|
chosen_ids = chosen_ids[:data_args.max_target_length]
|
||||||
|
if len(rejected_ids) > data_args.max_target_length:
|
||||||
|
rejected_ids = rejected_ids[:data_args.max_target_length]
|
||||||
|
|
||||||
if len(source_ids) > data_args.max_source_length:
|
model_inputs["prompt_ids"].append(prompt_ids)
|
||||||
source_ids = source_ids[:data_args.max_source_length]
|
model_inputs["chosen_ids"].append(chosen_ids)
|
||||||
if len(accept_ids) > data_args.max_target_length - 1: # eos token
|
model_inputs["rejected_ids"].append(rejected_ids)
|
||||||
accept_ids = accept_ids[:data_args.max_target_length - 1]
|
|
||||||
if len(reject_ids) > data_args.max_target_length - 1: # eos token
|
|
||||||
reject_ids = reject_ids[:data_args.max_target_length - 1]
|
|
||||||
|
|
||||||
accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
|
|
||||||
reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
|
|
||||||
|
|
||||||
model_inputs["accept_ids"].append(accept_ids)
|
|
||||||
model_inputs["reject_ids"].append(reject_ids)
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
def print_supervised_dataset_example(example):
|
def print_supervised_dataset_example(example):
|
||||||
print("input_ids:\n{}".format(example["input_ids"]))
|
print("input_ids:\n{}".format(example["input_ids"]))
|
||||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||||
print("label_ids:\n{}".format(example["labels"]))
|
print("label_ids:\n{}".format(example["labels"]))
|
||||||
print("labels:\n{}".format(
|
print("labels:\n{}".format(tokenizer.decode([
|
||||||
tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
|
token_id if token_id != IGNORE_INDEX else tokenizer.pad_token_id for token_id in example["labels"]
|
||||||
skip_special_tokens=False)
|
], skip_special_tokens=False)))
|
||||||
))
|
|
||||||
|
|
||||||
def print_pairwise_dataset_example(example):
|
def print_pairwise_dataset_example(example):
|
||||||
print("accept_ids:\n{}".format(example["accept_ids"]))
|
print("prompt_ids:\n{}".format(example["prompt_ids"]))
|
||||||
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
|
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
|
||||||
print("reject_ids:\n{}".format(example["reject_ids"]))
|
print("chosen_ids:\n{}".format(example["chosen_ids"]))
|
||||||
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
|
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
|
||||||
|
print("rejected_ids:\n{}".format(example["rejected_ids"]))
|
||||||
|
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
|
||||||
|
|
||||||
def print_unsupervised_dataset_example(example):
|
def print_unsupervised_dataset_example(example):
|
||||||
print("input_ids:\n{}".format(example["input_ids"]))
|
print("input_ids:\n{}".format(example["input_ids"]))
|
||||||
@@ -173,8 +167,5 @@ def preprocess_dataset(
|
|||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.streaming:
|
|
||||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
|
|
||||||
|
|
||||||
print_function(next(iter(dataset)))
|
print_function(next(iter(dataset)))
|
||||||
return dataset
|
return dataset
|
||||||
|
|||||||
@@ -1,15 +1,59 @@
|
|||||||
from typing import TYPE_CHECKING, Dict
|
import hashlib
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from llmtuner.extras.logging import get_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from datasets import Dataset
|
from datasets import Dataset, IterableDataset
|
||||||
|
from transformers import TrainingArguments
|
||||||
|
from llmtuner.hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
|
logger = get_logger(__name__)
|
||||||
if do_train:
|
|
||||||
if dev_ratio > 1e-6: # Split the dataset
|
|
||||||
dataset = dataset.train_test_split(test_size=dev_ratio)
|
EXT2TYPE = {
|
||||||
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
"csv": "csv",
|
||||||
|
"json": "json",
|
||||||
|
"jsonl": "json",
|
||||||
|
"txt": "text"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
||||||
|
if file_sha1 is None:
|
||||||
|
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(data_files) != 1:
|
||||||
|
logger.warning("Checksum failed: too many files.")
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(data_files[0], "rb") as f:
|
||||||
|
sha1 = hashlib.sha1(f.read()).hexdigest()
|
||||||
|
if sha1 != file_sha1:
|
||||||
|
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
||||||
|
|
||||||
|
|
||||||
|
def split_dataset(
|
||||||
|
dataset: Union["Dataset", "IterableDataset"],
|
||||||
|
data_args: "DataArguments",
|
||||||
|
training_args: "TrainingArguments"
|
||||||
|
) -> Dict[str, "Dataset"]:
|
||||||
|
if training_args.do_train:
|
||||||
|
if data_args.val_size > 1e-6: # Split the dataset
|
||||||
|
if data_args.streaming:
|
||||||
|
val_set = dataset.take(int(data_args.val_size))
|
||||||
|
train_set = dataset.skip(int(data_args.val_size))
|
||||||
|
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||||
|
return {"train_dataset": train_set, "eval_dataset": val_set}
|
||||||
|
else:
|
||||||
|
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
||||||
|
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
|
||||||
|
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||||
else:
|
else:
|
||||||
|
if data_args.streaming:
|
||||||
|
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||||
return {"train_dataset": dataset}
|
return {"train_dataset": dataset}
|
||||||
else: # do_eval or do_predict
|
else: # do_eval or do_predict
|
||||||
return {"eval_dataset": dataset}
|
return {"eval_dataset": dataset}
|
||||||
|
|||||||
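
The reworked `split_dataset` above carves the validation set out of a streaming dataset with `take`/`skip` before shuffling the training stream, instead of calling `train_test_split`. A small sketch of those calls against a local JSON file; the file path, `val_size`, and `buffer_size` are arbitrary examples:

```python
# Sketch of the streaming split used by split_dataset(): take() the first
# val_size examples for evaluation, skip() them for training, then shuffle the
# training stream with a buffer. Path and sizes are arbitrary examples.
from datasets import load_dataset

stream = load_dataset("json", data_files="data/alpaca_gpt4_data_zh.json", split="train", streaming=True)

val_size = 100
val_set = stream.take(val_size)
train_set = stream.skip(val_size).shuffle(buffer_size=16384, seed=42)

print(next(iter(val_set)))
```
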
@@ -5,67 +5,124 @@ from typing import TYPE_CHECKING
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
|
from transformers.trainer_utils import has_length
|
||||||
|
|
||||||
|
from llmtuner.extras.constants import LOG_FILE_NAME
|
||||||
|
from llmtuner.extras.logging import get_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import TrainingArguments, TrainerState, TrainerControl
|
from transformers import TrainingArguments, TrainerState, TrainerControl
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LogCallback(TrainerCallback):
|
class LogCallback(TrainerCallback):
|
||||||
|
|
||||||
def __init__(self, runner=None):
|
def __init__(self, runner=None):
|
||||||
self.runner = runner
|
self.runner = runner
|
||||||
|
self.in_training = False
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
self.tracker = {}
|
self.cur_steps = 0
|
||||||
|
self.max_steps = 0
|
||||||
|
self.elapsed_time = ""
|
||||||
|
self.remaining_time = ""
|
||||||
|
|
||||||
|
def timing(self):
|
||||||
|
cur_time = time.time()
|
||||||
|
elapsed_time = cur_time - self.start_time
|
||||||
|
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
|
||||||
|
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
|
||||||
|
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
|
||||||
|
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
|
||||||
|
|
||||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Event called at the beginning of training.
|
Event called at the beginning of training.
|
||||||
"""
|
"""
|
||||||
self.start_time = time.time()
|
if state.is_local_process_zero:
|
||||||
|
self.in_training = True
|
||||||
|
self.start_time = time.time()
|
||||||
|
self.max_steps = state.max_steps
|
||||||
|
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)):
|
||||||
|
logger.warning("Previous log file in this folder will be deleted.")
|
||||||
|
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
|
||||||
|
|
||||||
def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Event called at the beginning of a training step. If using gradient accumulation, one training step
|
Event called at the end of training.
|
||||||
might take several inputs.
|
|
||||||
"""
|
"""
|
||||||
if self.runner is not None and self.runner.aborted:
|
if state.is_local_process_zero:
|
||||||
control.should_epoch_stop = True
|
self.in_training = False
|
||||||
control.should_training_stop = True
|
self.cur_steps = 0
|
||||||
|
self.max_steps = 0
|
||||||
|
|
||||||
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Event called at the end of an substep during gradient accumulation.
|
Event called at the end of an substep during gradient accumulation.
|
||||||
"""
|
"""
|
||||||
if self.runner is not None and self.runner.aborted:
|
if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
|
||||||
control.should_epoch_stop = True
|
control.should_epoch_stop = True
|
||||||
control.should_training_stop = True
|
control.should_training_stop = True
|
||||||
|
|
||||||
|
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
|
r"""
|
||||||
|
Event called at the end of a training step.
|
||||||
|
"""
|
||||||
|
if state.is_local_process_zero:
|
||||||
|
self.cur_steps = state.global_step
|
||||||
|
self.timing()
|
||||||
|
if self.runner is not None and self.runner.aborted:
|
||||||
|
control.should_epoch_stop = True
|
||||||
|
control.should_training_stop = True
|
||||||
|
|
||||||
|
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
|
r"""
|
||||||
|
Event called after an evaluation phase.
|
||||||
|
"""
|
||||||
|
if state.is_local_process_zero and not self.in_training:
|
||||||
|
self.cur_steps = 0
|
||||||
|
self.max_steps = 0
|
||||||
|
|
||||||
|
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
|
||||||
|
r"""
|
||||||
|
Event called after a successful prediction.
|
||||||
|
"""
|
||||||
|
if state.is_local_process_zero and not self.in_training:
|
||||||
|
self.cur_steps = 0
|
||||||
|
self.max_steps = 0
|
||||||
|
|
||||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
|
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
|
||||||
r"""
|
r"""
|
||||||
Event called after logging the last logs.
|
Event called after logging the last logs.
|
||||||
"""
|
"""
|
||||||
if not state.is_world_process_zero:
|
if not state.is_local_process_zero:
|
||||||
return
|
return
|
||||||
|
|
||||||
cur_time = time.time()
|
logs = dict(
|
||||||
cur_steps = state.log_history[-1].get("step")
|
current_steps=self.cur_steps,
|
||||||
elapsed_time = cur_time - self.start_time
|
total_steps=self.max_steps,
|
||||||
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
loss=state.log_history[-1].get("loss", None),
|
||||||
remaining_steps = state.max_steps - cur_steps
|
eval_loss=state.log_history[-1].get("eval_loss", None),
|
||||||
remaining_time = remaining_steps * avg_time_per_step
|
predict_loss=state.log_history[-1].get("predict_loss", None),
|
||||||
self.tracker = {
|
reward=state.log_history[-1].get("reward", None),
|
||||||
"current_steps": cur_steps,
|
learning_rate=state.log_history[-1].get("learning_rate", None),
|
||||||
"total_steps": state.max_steps,
|
epoch=state.log_history[-1].get("epoch", None),
|
||||||
"loss": state.log_history[-1].get("loss", None),
|
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||||
"eval_loss": state.log_history[-1].get("eval_loss", None),
|
elapsed_time=self.elapsed_time,
|
||||||
"predict_loss": state.log_history[-1].get("predict_loss", None),
|
remaining_time=self.remaining_time
|
||||||
"reward": state.log_history[-1].get("reward", None),
|
)
|
||||||
"learning_rate": state.log_history[-1].get("learning_rate", None),
|
|
||||||
"epoch": state.log_history[-1].get("epoch", None),
|
|
||||||
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
|
|
||||||
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
|
|
||||||
"remaining_time": str(timedelta(seconds=int(remaining_time)))
|
|
||||||
}
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||||
f.write(json.dumps(self.tracker) + "\n")
|
f.write(json.dumps(logs) + "\n")
|
||||||
|
|
||||||
|
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
|
r"""
|
||||||
|
Event called after a prediction step.
|
||||||
|
"""
|
||||||
|
eval_dataloader = kwargs.pop("eval_dataloader", None)
|
||||||
|
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
|
||||||
|
if self.max_steps == 0:
|
||||||
|
self.max_steps = len(eval_dataloader)
|
||||||
|
self.cur_steps += 1
|
||||||
|
self.timing()
|
||||||
|
|||||||
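
Because `LogCallback.on_log` now writes one JSON object per logging step to `trainer_log.jsonl` in the output directory, training progress can be monitored without the web UI. A minimal sketch; the output directory is a placeholder and the keys match the dict built in `on_log`:

```python
# Read the progress records emitted by LogCallback. Keys mirror the dict built
# in on_log() above; the directory below is a placeholder.
import json

with open("path_to_sft_checkpoint/trainer_log.jsonl", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        print("step {current_steps}/{total_steps}  loss={loss}  "
              "elapsed={elapsed_time}  remaining={remaining_time}".format(**record))
```
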
@@ -1,13 +1,23 @@
|
|||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
|
|
||||||
|
LOG_FILE_NAME = "trainer_log.jsonl"
|
||||||
|
|
||||||
VALUE_HEAD_FILE_NAME = "value_head.bin"
|
VALUE_HEAD_FILE_NAME = "value_head.bin"
|
||||||
|
|
||||||
FINETUNING_ARGS_NAME = "finetuning_args.json"
|
FINETUNING_ARGS_NAME = "finetuning_args.json"
|
||||||
|
|
||||||
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
|
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]
|
||||||
|
|
||||||
METHODS = ["full", "freeze", "lora"]
|
METHODS = ["full", "freeze", "lora"]
|
||||||
|
|
||||||
|
STAGES = [
|
||||||
|
"SFT",
|
||||||
|
"Reward Modeling",
|
||||||
|
"PPO",
|
||||||
|
"DPO",
|
||||||
|
"Pre-Training"
|
||||||
|
]
|
||||||
|
|
||||||
SUPPORTED_MODELS = {
|
SUPPORTED_MODELS = {
|
||||||
"LLaMA-7B": "huggyllama/llama-7b",
|
"LLaMA-7B": "huggyllama/llama-7b",
|
||||||
"LLaMA-13B": "huggyllama/llama-13b",
|
"LLaMA-13B": "huggyllama/llama-13b",
|
||||||
@@ -19,29 +29,50 @@ SUPPORTED_MODELS = {
|
|||||||
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
|
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
|
||||||
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
|
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
|
||||||
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
|
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
|
||||||
|
"ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
|
||||||
|
"ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
|
||||||
|
"ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
|
||||||
|
"ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
|
||||||
"BLOOM-560M": "bigscience/bloom-560m",
|
"BLOOM-560M": "bigscience/bloom-560m",
|
||||||
"BLOOM-3B": "bigscience/bloom-3b",
|
"BLOOM-3B": "bigscience/bloom-3b",
|
||||||
"BLOOM-7B1": "bigscience/bloom-7b1",
|
"BLOOM-7B1": "bigscience/bloom-7b1",
|
||||||
"BLOOMZ-560M": "bigscience/bloomz-560m",
|
"BLOOMZ-560M": "bigscience/bloomz-560m",
|
||||||
"BLOOMZ-3B": "bigscience/bloomz-3b",
|
"BLOOMZ-3B": "bigscience/bloomz-3b",
|
||||||
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
|
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
|
||||||
"Falcon-7B-Base": "tiiuae/falcon-7b",
|
"Falcon-7B": "tiiuae/falcon-7b",
|
||||||
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
|
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
|
||||||
"Falcon-40B-Base": "tiiuae/falcon-40b",
|
"Falcon-40B": "tiiuae/falcon-40b",
|
||||||
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
|
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
|
||||||
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
|
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
|
||||||
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
|
"Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
|
||||||
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
|
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
|
||||||
"InternLM-7B-Base": "internlm/internlm-7b",
|
"InternLM-7B": "internlm/internlm-7b",
|
||||||
"InternLM-7B-Chat": "internlm/internlm-chat-7b"
|
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
|
||||||
|
"Qwen-7B": "Qwen/Qwen-7B",
|
||||||
|
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||||
|
"XVERSE-13B": "xverse/XVERSE-13B",
|
||||||
|
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_MODULE = {
|
DEFAULT_MODULE = {
|
||||||
"LLaMA": "q_proj,v_proj",
|
"LLaMA": "q_proj,v_proj",
|
||||||
"LLaMA2": "q_proj,v_proj",
|
"LLaMA2": "q_proj,v_proj",
|
||||||
|
"ChineseLLaMA2": "q_proj,v_proj",
|
||||||
"BLOOM": "query_key_value",
|
"BLOOM": "query_key_value",
|
||||||
"BLOOMZ": "query_key_value",
|
"BLOOMZ": "query_key_value",
|
||||||
"Falcon": "query_key_value",
|
"Falcon": "query_key_value",
|
||||||
"Baichuan": "W_pack",
|
"Baichuan": "W_pack",
|
||||||
"InternLM": "q_proj,v_proj"
|
"InternLM": "q_proj,v_proj",
|
||||||
|
"Qwen": "c_attn",
|
||||||
|
"XVERSE": "q_proj,v_proj",
|
||||||
|
"ChatGLM2": "query_key_value"
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFAULT_TEMPLATE = {
|
||||||
|
"LLaMA2": "llama2",
|
||||||
|
"ChineseLLaMA2": "llama2_zh",
|
||||||
|
"Baichuan": "baichuan",
|
||||||
|
"InternLM": "intern",
|
||||||
|
"Qwen": "chatml",
|
||||||
|
"ChatGLM2": "chatglm2"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ class LoggerHandler(logging.Handler):
         super().__init__()
         self.log = ""
 
+    def reset(self):
+        self.log = ""
+
     def emit(self, record):
         if record.name == "httpx":
             return
@@ -1,8 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
||||||
from transformers.generation.utils import LogitsProcessorList
|
|
||||||
from transformers.generation.logits_process import LogitsProcessor
|
|
||||||
|
|
||||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||||
|
|
||||||
@@ -30,19 +28,9 @@ class AverageMeter:
|
|||||||
self.avg = self.sum / self.count
|
self.avg = self.sum / self.count
|
||||||
|
|
||||||
|
|
||||||
# Avoids runtime error in model.generate(do_sample=True).
|
|
||||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
|
||||||
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
|
||||||
scores.zero_()
|
|
||||||
scores[..., 0] = 1.0
|
|
||||||
return scores
|
|
||||||
|
|
||||||
|
|
||||||
def get_logits_processor() -> LogitsProcessorList:
|
def get_logits_processor() -> LogitsProcessorList:
|
||||||
logits_processor = LogitsProcessorList()
|
logits_processor = LogitsProcessorList()
|
||||||
logits_processor.append(InvalidScoreLogitsProcessor())
|
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||||
return logits_processor
|
return logits_processor
|
||||||
|
|
||||||
|
|
||||||
@@ -77,7 +65,6 @@ def prepare_model_for_training(
|
|||||||
use_gradient_checkpointing: Optional[bool] = True,
|
use_gradient_checkpointing: Optional[bool] = True,
|
||||||
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
|
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
|
||||||
) -> "PreTrainedModel":
|
) -> "PreTrainedModel":
|
||||||
|
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||||
param.data = param.data.to(torch.float32)
|
param.data = param.data.to(torch.float32)
|
||||||
@@ -94,9 +81,6 @@ def prepare_model_for_training(
|
|||||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||||
|
|
||||||
if finetuning_type != "full" and hasattr(model, output_layer_name):
|
if finetuning_type != "full" and hasattr(model, output_layer_name):
|
||||||
if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
|
|
||||||
model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
|
|
||||||
|
|
||||||
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
|
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
|
||||||
input_dtype = output_layer.weight.dtype
|
input_dtype = output_layer.weight.dtype
|
||||||
|
|
||||||
@@ -124,6 +108,9 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
|||||||
Dispatches a pre-trained model to GPUs with balanced memory.
|
Dispatches a pre-trained model to GPUs with balanced memory.
|
||||||
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
|
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
|
||||||
"""
|
"""
|
||||||
|
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
|
||||||
|
return model
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
from accelerate import dispatch_model
|
from accelerate import dispatch_model
|
||||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
||||||
|
|||||||
@@ -1,92 +1,230 @@
|
|||||||
from typing import Dict, List, Optional, Tuple
|
import tiktoken
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from llmtuner.extras.logging import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Template:
|
class Template:
|
||||||
|
|
||||||
prefix: str
|
prefix: List[Union[str, Dict[str, str]]]
|
||||||
prompt: str
|
prompt: List[Union[str, Dict[str, str]]]
|
||||||
sep: str
|
system: str
|
||||||
|
sep: List[Union[str, Dict[str, str]]]
|
||||||
|
stop_words: List[str]
|
||||||
     use_history: bool

-    def get_prompt(
+    def encode_oneturn(
         self,
+        tokenizer: "PreTrainedTokenizer",
         query: str,
+        resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = "",
-        eos_token: Optional[str] = "</s>"
-    ) -> str:
+        system: Optional[str] = None
+    ) -> Tuple[List[int], List[int]]:
         r"""
-        Returns a string containing prompt without response.
+        Returns a single pair of token ids representing prompt and response respectively.
         """
-        return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix)))
+        system, history = self._format(query, resp, history, system)
+        encoded_pairs = self._encode(tokenizer, system, history)
+        prompt_ids = []
+        for query_ids, resp_ids in encoded_pairs[:-1]:
+            prompt_ids = prompt_ids + query_ids + resp_ids
+        prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
+        return prompt_ids, answer_ids

-    def get_dialog(
+    def encode_multiturn(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        query: str,
+        resp: str,
+        history: Optional[List[Tuple[str, str]]] = None,
+        system: Optional[str] = None
+    ) -> List[Tuple[List[int], List[int]]]:
+        r"""
+        Returns multiple pairs of token ids representing prompts and responses respectively.
+        """
+        system, history = self._format(query, resp, history, system)
+        encoded_pairs = self._encode(tokenizer, system, history)
+        return encoded_pairs
+
+    def _format(
         self,
         query: str,
         resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = ""
-    ) -> List[Tuple[str, str]]:
+        system: Optional[str] = None
+    ) -> Tuple[str, List[Tuple[str, str]]]:
         r"""
-        Returns a list containing prompt-response pairs.
+        Aligns inputs to the standard format.
         """
-        result = self._format_example(query, history, prefix)
-        result[-1][-1] = resp
-        return result
-
-    def _format_example(
-        self,
-        query: str,
-        history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = ""
-    ) -> List[Tuple[str, str]]:
-        prefix = prefix or self.prefix # use prefix if provided
-        prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
+        system = system or self.system # use system if provided
         history = history if (history and self.use_history) else []
-        history = history + [(query, "")]
-        return [
-            [(self.sep if i else prefix) + self.prompt.format(query=q), r]
-            for i, (q, r) in enumerate(history)
-        ]
+        history = history + [(query, resp)]
+        return system, history
+
+    def _get_special_ids(
+        self,
+        tokenizer: "PreTrainedTokenizer"
+    ) -> Tuple[List[int], List[int]]:
+        if (
+            tokenizer.bos_token_id is not None
+            and getattr(tokenizer, "add_bos_token", True)
+        ): # baichuan-13b has no bos token
+            bos_ids = [tokenizer.bos_token_id]
+        else:
+            bos_ids = [] # bos token is optional
+
+        if tokenizer.eos_token_id is not None:
+            eos_ids = [tokenizer.eos_token_id]
+        else:
+            raise ValueError("EOS token is required.")
+
+        return bos_ids, eos_ids
+
+    def _encode(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        system: str,
+        history: List[Tuple[str, str]]
+    ) -> List[Tuple[List[int], List[int]]]:
+        r"""
+        Encodes formatted inputs to pairs of token ids.
+        Turn 0: bos + prefix + sep + query resp + eos
+        Turn t: sep + bos + query resp + eos
+        """
+        bos_ids, eos_ids = self._get_special_ids(tokenizer)
+        sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
+        encoded_pairs = []
+        for turn_idx, (query, resp) in enumerate(history):
+            if turn_idx == 0:
+                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
+                if len(prefix_ids) != 0: # has prefix
+                    prefix_ids = bos_ids + prefix_ids + sep_ids
+                else:
+                    prefix_ids = bos_ids
+            else:
+                prefix_ids = sep_ids + bos_ids
+
+            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
+            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
+            encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
+        return encoded_pairs
+
+    def _convert_inputs_to_ids(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        context: List[Union[str, Dict[str, str]]],
+        system: Optional[str] = None,
+        query: Optional[str] = None,
+        idx: Optional[str] = None
+    ) -> List[int]:
+        r"""
+        Converts context to token ids.
+        """
+        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
+            kwargs = dict(allowed_special="all")
+        else:
+            kwargs = dict(add_special_tokens=False)
+
+        token_ids = []
+        for elem in context:
+            if isinstance(elem, str):
+                elem = elem.replace("{{system}}", system, 1) if system is not None else elem
+                elem = elem.replace("{{query}}", query, 1) if query is not None else elem
+                elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
+                token_ids = token_ids + tokenizer.encode(elem, **kwargs)
+            elif isinstance(elem, dict):
+                token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
+            else:
+                raise NotImplementedError
+
+        return token_ids

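As a rough usage sketch (not part of this changeset): the reworked Template API returns token ids directly instead of a formatted string. The model name and the query/response strings below are placeholders.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint
template = templates["default"]  # assumes the registrations further down have run

prompt_ids, answer_ids = template.encode_oneturn(
    tokenizer=tokenizer,
    query="What does LoRA do?",
    resp="It adds trainable low-rank matrices to frozen weights.",
    history=None,
    system=None  # falls back to template.system
)
# prompt_ids form the model input; answer_ids are the supervised labels for the last turn.
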
 @dataclass
 class Llama2Template(Template):

-    def _format_example(
+    def _encode(
         self,
-        query: str,
-        history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = ""
-    ) -> List[Tuple[str, str]]:
-        prefix = prefix or self.prefix # use prefix if provided
-        prefix = prefix if prefix.startswith("<<SYS>>") else "<<SYS>>\n{}\n<</SYS>>\n\n".format(prefix)
-        history = history if (history and self.use_history) else []
-        history = history + [(query, "")]
-        return [
-            [(self.sep if i else "") + self.prompt.format(query=(q if i else prefix + q)), r]
-            for i, (q, r) in enumerate(history)
-        ]
+        tokenizer: "PreTrainedTokenizer",
+        system: str,
+        history: List[Tuple[str, str]]
+    ) -> List[Tuple[List[int], List[int]]]:
+        r"""
+        Encodes formatted inputs to pairs of token ids.
+        Turn 0: bos + prefix + query resp + eos
+        Turn t: bos + query resp + eos
+        """
+        bos_ids, eos_ids = self._get_special_ids(tokenizer)
+        encoded_pairs = []
+        for turn_idx, (query, resp) in enumerate(history):
+            if turn_idx == 0: # llama2 template has no sep_ids
+                query = self.prefix[0].replace("{{system}}", system) + query
+            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
+            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
+            encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
+        return encoded_pairs

 templates: Dict[str, Template] = {}


-def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
-    template_class = Llama2Template if name == "llama2" else Template
+def register_template(
+    name: str,
+    prefix: List[Union[str, Dict[str, str]]],
+    prompt: List[Union[str, Dict[str, str]]],
+    system: str,
+    sep: List[Union[str, Dict[str, str]]],
+    stop_words: Optional[List[str]] = [],
+    use_history: Optional[bool] = True
+) -> None:
+    template_class = Llama2Template if "llama2" in name else Template
     templates[name] = template_class(
         prefix=prefix,
         prompt=prompt,
+        system=system,
         sep=sep,
+        stop_words=stop_words,
         use_history=use_history
     )

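A hypothetical sketch of registering a custom template under the new signature; the name and all strings are illustrative only. "{{system}}" and "{{query}}" are the placeholder convention, and dict entries such as {"token": "..."} are encoded as single special tokens by _convert_inputs_to_ids.

register_template(
    name="my_custom",  # hypothetical template name
    prefix=[
        "{{system}}"
    ],
    prompt=[
        "### User:\n{{query}}\n\n### Bot:\n"
    ],
    system="You are a concise assistant.",
    sep=[
        "\n"
    ],
    stop_words=[],
    use_history=True
)
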
-def get_template(name: str) -> Template:
+def get_template_and_fix_tokenizer(
+    name: str,
+    tokenizer: "PreTrainedTokenizer"
+) -> Template:
     template = templates.get(name, None)
     assert template is not None, "Template {} does not exist.".format(name)

+    additional_special_tokens = template.stop_words
+    if len(template.stop_words): # inplace method
+        if tokenizer.eos_token_id is not None:
+            additional_special_tokens.append(tokenizer.eos_token)
+
+        tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
+        additional_special_tokens.pop(0)
+        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token = "<|endoftext|>"
+        logger.info("Add eos token: {}".format(tokenizer.eos_token))
+
+    if tokenizer.pad_token_id is None:
+        if tokenizer.unk_token_id is not None:
+            tokenizer.pad_token = tokenizer.unk_token
+        else:
+            tokenizer.pad_token = tokenizer.eos_token
+        logger.info("Add pad token: {}".format(tokenizer.pad_token))
+
+    tokenizer.add_special_tokens(dict(additional_special_tokens=additional_special_tokens))
     return template

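A minimal sketch, assuming an InternLM chat checkpoint, of how the new helper might be called; the model name is a placeholder.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)  # placeholder
template = get_template_and_fix_tokenizer("intern", tokenizer)
# The template's stop words are folded into the tokenizer here: the first stop word becomes
# the EOS token (the old EOS is kept as an extra special token), the remaining stop words are
# registered as additional special tokens, and a pad token is added if the tokenizer lacks one.
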
@@ -95,9 +233,12 @@ Supports language model inference without histories.
 """
 register_template(
     name="vanilla",
-    prefix="",
-    prompt="{query}",
-    sep="",
+    prefix=[],
+    prompt=[
+        "{{query}}"
+    ],
+    system="",
+    sep=[],
     use_history=False
 )


@@ -107,11 +248,19 @@ Default template.
 """
 register_template(
     name="default",
-    prefix="A chat between a curious user and an artificial intelligence assistant. "
-           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
-    prompt="Human: {query}\nAssistant: ",
-    sep="\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "Human: {{query}}\nAssistant: "
+    ],
+    system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    sep=[
+        "\n"
+    ]
 )


@@ -122,17 +271,40 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
 """
 register_template(
     name="llama2",
-    prefix="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
-           "Always answer as helpfully as possible, while being safe. "
-           "Your answers should not include any harmful, unethical, "
-           "racist, sexist, toxic, dangerous, or illegal content. "
-           "Please ensure that your responses are socially unbiased and positive in nature.\n"
-           "If a question does not make any sense, or is not factually coherent, "
-           "explain why instead of answering something not correct. "
-           "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
-    prompt="[INST] {query} [/INST] ",
-    sep="<s>",
-    use_history=True
+    prefix=[
+        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
+    ],
+    prompt=[
+        "[INST] {{query}} [/INST] "
+    ],
+    system=(
+        "You are a helpful, respectful and honest assistant. "
+        "Always answer as helpfully as possible, while being safe. "
+        "Your answers should not include any harmful, unethical, "
+        "racist, sexist, toxic, dangerous, or illegal content. "
+        "Please ensure that your responses are socially unbiased and positive in nature.\n"
+        "If a question does not make any sense, or is not factually coherent, "
+        "explain why instead of answering something not correct. "
+        "If you don't know the answer to a question, please don't share false information."
+    ),
+    sep=[]
+)
+
+
+r"""
+Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
+          https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
+"""
+register_template(
+    name="llama2_zh",
+    prefix=[
+        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
+    ],
+    prompt=[
+        "[INST] {{query}} [/INST] "
+    ],
+    system="You are a helpful assistant. 你是一个乐于助人的助手。",
+    sep=[]
 )


@@ -142,11 +314,19 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
 """
 register_template(
     name="alpaca",
-    prefix="Below is an instruction that describes a task. "
-           "Write a response that appropriately completes the request.",
-    prompt="### Instruction:\n{query}\n\n### Response:\n",
-    sep="\n\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "### Instruction:\n{{query}}\n\n### Response:\n"
+    ],
+    system=(
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request."
+    ),
+    sep=[
+        "\n\n"
+    ]
 )


@@ -156,11 +336,17 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
 """
 register_template(
     name="vicuna",
-    prefix="A chat between a curious user and an artificial intelligence assistant. "
-           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
-    prompt="USER: {query} ASSISTANT: ",
-    sep="",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "USER: {{query}} ASSISTANT: "
+    ],
+    system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    sep=[]
 )


@@ -169,10 +355,16 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
 """
 register_template(
     name="belle",
-    prefix="",
-    prompt="Human: {query}\n\nBelle: ",
-    sep="\n\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "Human: {{query}}\n\nBelle: "
+    ],
+    system="",
+    sep=[
+        "\n\n"
+    ]
 )


@@ -181,10 +373,16 @@ Supports: https://github.com/CVI-SZU/Linly
 """
 register_template(
     name="linly",
-    prefix="",
-    prompt="User: {query}\nBot: ",
-    sep="\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "User: {{query}}\nBot: "
+    ],
+    system="",
+    sep=[
+        "\n"
+    ]
 )


@@ -193,10 +391,16 @@ Supports: https://github.com/Neutralzz/BiLLa
 """
 register_template(
     name="billa",
-    prefix="",
-    prompt="Human: {query}\nAssistant: ",
-    sep="\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "Human: {{query}}\nAssistant: "
+    ],
+    system="",
+    sep=[
+        "\n"
+    ]
 )


@@ -205,10 +409,19 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
 """
 register_template(
     name="ziya",
-    prefix="",
-    prompt="<human>:{query}\n<bot>:",
-    sep="\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        {"token": "<human>"},
+        ":{{query}}\n",
+        {"token": "<bot>"},
+        ":"
+    ],
+    system="",
+    sep=[
+        "\n"
+    ]
 )


@@ -217,11 +430,19 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b
 """
 register_template(
     name="aquila",
-    prefix="A chat between a curious human and an artificial intelligence assistant. "
-           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
-    prompt="Human: {query}###Assistant: ",
-    sep="###",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "Human: {{query}}###Assistant: "
+    ],
+    system=(
+        "A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions."
+    ),
+    sep=[
+        "###"
+    ]
 )


@@ -230,10 +451,22 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b
 """
 register_template(
     name="intern",
-    prefix="",
-    prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
-    sep="<eoa>\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "<|User|>:{{query}}",
+        {"token": "<eoh>"},
+        "\n<|Bot|>:"
+    ],
+    system="",
+    sep=[
+        "\n"
+    ],
+    stop_words=[
+        "</s>", # internlm cannot replace eos token
+        "<eoa>"
+    ]
 )


@@ -242,10 +475,19 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
 """
 register_template(
     name="baichuan",
-    prefix="",
-    prompt="<reserved_102>{query}<reserved_103>",
-    sep="",
-    use_history=True
+    prefix=[
+        "{{system}}",
+        {"token": "<reserved_102>"} # user token
+    ],
+    prompt=[
+        "{{query}}",
+        {"token": "<reserved_103>"} # assistant token
+    ],
+    system="",
+    sep=[],
+    stop_words=[
+        "<reserved_102>" # user token
+    ]
 )


@@ -255,8 +497,71 @@ Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
 """
 register_template(
     name="starchat",
-    prefix="<|system|>\n",
-    prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
-    sep="<|end|>\n",
-    use_history=True
+    prefix=[
+        {"token": "<|system|>"},
+        "\n{{system}}",
+        {"token": "<|end|>"}
+    ],
+    prompt=[
+        {"token": "<|user|>"},
+        "\n{{query}}",
+        {"token": "<|end|>"},
+        "\n",
+        {"token": "<|assistant|>"}
+    ],
+    system="",
+    sep=[
+        "\n"
+    ],
+    stop_words=[
+        "<|end|>"
+    ]
+)
+
+
+r"""
+Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
+"""
+register_template(
+    name="chatml",
+    prefix=[
+        {"token": "<|im_start|>"},
+        "system\n{{system}}",
+        {"token": "<|im_end|>"}
+    ],
+    prompt=[
+        {"token": "<|im_start|>"},
+        "user\n{{query}}",
+        {"token": "<|im_end|>"},
+        "\n",
+        {"token": "<|im_start|>"},
+        "assistant\n"
+    ],
+    system="You are a helpful assistant.",
+    sep=[
+        "\n"
+    ],
+    stop_words=[
+        "<|im_end|>"
+    ]
+)
+
+
+r"""
+Supports: https://huggingface.co/THUDM/chatglm2-6b
+"""
+register_template(
+    name="chatglm2",
+    prefix=[
+        {"token": "[gMASK]"},
+        {"token": "sop"},
+        "{{system}}"
+    ],
+    prompt=[
+        "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
+    ],
+    system="",
+    sep=[
+        "\n\n"
+    ]
 )

@@ -10,7 +10,7 @@ class DatasetAttr:
     load_from: str
     dataset_name: Optional[str] = None
     dataset_sha1: Optional[str] = None
-    source_prefix: Optional[str] = None
+    system_prompt: Optional[str] = None

     def __repr__(self) -> str:
         return self.dataset_name
@@ -24,7 +24,7 @@ class DatasetAttr:

 @dataclass
 class DataArguments:
-    """
+    r"""
     Arguments pertaining to what data we are going to input our model for training and evaluation.
     """
     template: str = field(
@@ -54,6 +54,10 @@ class DataArguments:
         default="concat",
         metadata={"help": "Strategy to use in dataset mixing."}
     )
+    interleave_probs: Optional[str] = field(
+        default=None,
+        metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
+    )
     overwrite_cache: Optional[bool] = field(
         default=False,
         metadata={"help": "Overwrite the cached training and evaluation sets."}
@@ -82,13 +86,13 @@ class DataArguments:
         default=True,
         metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
     )
-    source_prefix: Optional[str] = field(
+    system_prompt: Optional[str] = field(
         default=None,
-        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
+        metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
     )
-    dev_ratio: Optional[float] = field(
+    val_size: Optional[float] = field(
         default=0,
-        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
+        metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
     )

     def init_for_training(self): # support mixing multiple datasets
@@ -96,12 +100,12 @@ class DataArguments:
         with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
             dataset_info = json.load(f)

-        if self.source_prefix is not None:
-            prefix_list = self.source_prefix.split("|")
-            prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
-            assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
-        else:
-            prefix_list = [None] * len(dataset_names)
+        prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
+        prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
+        assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
+
+        if self.interleave_probs is not None:
+            self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]

         self.dataset_list: List[DatasetAttr] = []
         for i, name in enumerate(dataset_names):
@@ -119,12 +123,11 @@ class DataArguments:
                 dataset_sha1=dataset_info[name].get("file_sha1", None)
             )

-            dataset_attr.source_prefix = prefix_list[i]
-
             if "columns" in dataset_info[name]:
                 dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
                 dataset_attr.query = dataset_info[name]["columns"].get("query", None)
                 dataset_attr.response = dataset_info[name]["columns"].get("response", None)
                 dataset_attr.history = dataset_info[name]["columns"].get("history", None)

+            dataset_attr.system_prompt = prompt_list[i]
             self.dataset_list.append(dataset_attr)

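A small sketch reproducing the prompt-to-dataset pairing above with placeholder values; the dataset names are illustrative, not from this commit.

dataset_names = ["alpaca_en", "alpaca_zh"]  # placeholder dataset names
system_prompt = "You are a helpful assistant.|你是一个乐于助人的助手。"

prompt_list = system_prompt.split("|") if system_prompt else [None]
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
assert len(prompt_list) == len(dataset_names)
# dataset_names[i] is paired with prompt_list[i]; a single prompt (or none) repeats across all datasets.
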
@@ -5,32 +5,36 @@ from dataclasses import asdict, dataclass, field

 @dataclass
 class FinetuningArguments:
-    """
+    r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
-    finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
+    finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
         default="lora",
         metadata={"help": "Which fine-tuning method to use."}
     )
     num_hidden_layers: Optional[int] = field(
         default=32,
-        metadata={"help": "Number of decoder blocks in the model. \
+        metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
                   LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
                   LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
                   BLOOM choices: [\"24\", \"30\", \"70\"], \
                   Falcon choices: [\"32\", \"60\"], \
-                  Baichuan choices: [\"32\", \"40\"]"}
+                  Baichuan choices: [\"32\", \"40\"] \
+                  Qwen choices: [\"32\"], \
+                  XVERSE choices: [\"40\"]"}
     )
     num_layer_trainable: Optional[int] = field(
         default=3,
-        metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
+        metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
     )
     name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
         default="mlp",
-        metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
-                  LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
+        metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
+                  LLaMA choices: [\"mlp\", \"self_attn\"], \
                   BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
-                  Baichuan choices: [\"mlp\", \"self_attn\"]"}
+                  Baichuan choices: [\"mlp\", \"self_attn\"], \
+                  Qwen choices: [\"mlp\", \"attn\"], \
+                  LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
     )
     lora_rank: Optional[int] = field(
         default=8,
@@ -45,11 +49,25 @@ class FinetuningArguments:
         metadata={"help": "Dropout rate for the LoRA fine-tuning."}
     )
     lora_target: Optional[str] = field(
-        default="q_proj,v_proj",
+        default=None,
         metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
-                  LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                   BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
-                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
+                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+                  Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
+                  LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
+    )
+    resume_lora_training: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+    )
+    ppo_score_norm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Use score normalization in PPO Training."}
+    )
+    dpo_beta: Optional[float] = field(
+        default=0.1,
+        metadata={"help": "The beta parameter for the DPO loss."}
     )

     def __post_init__(self):
@@ -63,17 +81,17 @@ class FinetuningArguments:

         self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]

-        assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
+        assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."

     def save_to_json(self, json_path: str):
-        """Saves the content of this instance in JSON format inside `json_path`."""
+        r"""Saves the content of this instance in JSON format inside `json_path`."""
         json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
         with open(json_path, "w", encoding="utf-8") as f:
             f.write(json_string)

     @classmethod
     def load_from_json(cls, json_path: str):
-        """Creates an instance from the content of `json_path`."""
+        r"""Creates an instance from the content of `json_path`."""
         with open(json_path, "r", encoding="utf-8") as f:
             text = f.read()
         return cls(**json.loads(text))

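Since lora_target now defaults to None and LoRA training later refuses to start without it, an argument dict along these lines is presumably expected; every value below is a placeholder, not taken from this commit.

args = dict(
    stage="sft",
    do_train=True,
    model_name_or_path="meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    finetuning_type="lora",
    lora_rank=8,
    lora_target="q_proj,v_proj",
    resume_lora_training=True,
    output_dir="path/to/output"  # placeholder path
)
# The keys mirror the dataclass fields above and would be consumed by the
# HfArgumentParser-based get_train_args(args) shown later in this compare view.
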
@@ -4,10 +4,10 @@ from dataclasses import dataclass, field

 @dataclass
 class GeneralArguments:
-    """
+    r"""
     Arguments pertaining to which stage we are going to perform.
     """
-    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
+    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
         default="sft",
         metadata={"help": "Which stage will be performed in training."}
     )

@@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field

 @dataclass
 class GeneratingArguments:
-    """
+    r"""
     Arguments pertaining to specify the decoding parameters.
     """
     do_sample: Optional[bool] = field(

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field

 @dataclass
 class ModelArguments:
-    """
+    r"""
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
     """
     model_name_or_path: str = field(
@@ -43,9 +43,9 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to use double quantization in int4 training or not."}
     )
-    compute_dtype: Optional[torch.dtype] = field(
+    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
         default=None,
-        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+        metadata={"help": "Adopt scaled rotary positional embeddings."}
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
@@ -55,18 +55,33 @@ class ModelArguments:
         default=None,
         metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
     )
-    resume_lora_training: Optional[bool] = field(
-        default=True,
-        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
-    )
     plot_loss: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
     )
+    hf_auth_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token to log in with Hugging Face Hub."}
+    )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+    )
+    model_max_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "Used in rope scaling. Do not specify this argument manually."}
+    )

     def __post_init__(self):
+        if self.compute_dtype is not None or self.model_max_length is not None:
+            raise ValueError("These arguments cannot be specified.")
+
         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

         if self.quantization_bit is not None:
             assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
+
+        if self.use_auth_token == True and self.hf_auth_token is not None:
+            from huggingface_hub.hf_api import HfFolder # lazy load
+            HfFolder.save_token(self.hf_auth_token)

@@ -1,5 +1 @@
-from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
-from llmtuner.tuner.pt import run_pt
-from llmtuner.tuner.sft import run_sft
-from llmtuner.tuner.rm import run_rm
-from llmtuner.tuner.ppo import run_ppo
+from llmtuner.tuner.tune import export_model, run_exp

@@ -39,7 +39,7 @@ def init_adapter(
     if finetuning_args.finetuning_type == "none" and is_trainable:
         raise ValueError("You cannot use finetuning_type=none while training.")

-    if finetuning_args.finetuning_type == "full":
+    if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
         model = model.float()

@@ -65,7 +65,7 @@ def init_adapter(
         assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
             "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."

-        if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
+        if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
             checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
         else:
             checkpoints_to_merge = model_args.checkpoint_dir

@@ -1,18 +1,21 @@
 import os
+import math
 import torch
+from types import MethodType
 from typing import TYPE_CHECKING, Literal, Optional, Tuple

 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
-    BitsAndBytesConfig
+    BitsAndBytesConfig,
+    PretrainedConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from transformers.deepspeed import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead

 from llmtuner.extras.logging import reset_logging, get_logger
@@ -22,6 +25,7 @@ from llmtuner.hparams import FinetuningArguments
 from llmtuner.tuner.core.adapter import init_adapter

 if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
     from llmtuner.hparams import ModelArguments


@@ -32,7 +36,7 @@ check_min_version("4.29.1")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
-require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
+require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")


 def load_model_and_tokenizer(
@@ -40,7 +44,7 @@ def load_model_and_tokenizer(
     finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
-) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
+) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
     r"""
     Loads pretrained model and tokenizer.

@@ -50,9 +54,6 @@ def load_model_and_tokenizer(
         logger.warning("Checkpoint is not found at evaluation, load the original model.")
         finetuning_args = FinetuningArguments(finetuning_type="none")

-    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
-        "RM and PPO training can only be performed with the LoRA method."
-
     config_kwargs = {
         "trust_remote_code": True,
         "cache_dir": model_args.cache_dir,
@@ -66,21 +67,67 @@ def load_model_and_tokenizer(
         padding_side=model_args.padding_side,
         **config_kwargs
     )
-    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
-        tokenizer.pad_token_id = 0 # set as the <unk> token

-    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
-    is_mergeable = True
+    if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
+        model_to_load = model_args.checkpoint_dir[0]
+    else:
+        model_to_load = model_args.model_name_or_path
+
+    config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
+
+    if hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
+        if model_args.compute_dtype == torch.bfloat16:
+            setattr(config, "bf16", True)
+        else:
+            setattr(config, "fp16", True)
+
+    # Set RoPE scaling
+    if model_args.rope_scaling is not None:
+        if hasattr(config, "use_dynamic_ntk"): # for Qwen models
+            if is_trainable:
+                logger.warning("Qwen model does not support RoPE scaling in training.")
+            else:
+                setattr(config, "use_dynamic_ntk", True)
+                setattr(config, "use_logn_attn", True)
+                logger.info("Using dynamic NTK scaling.")
+
+        elif hasattr(config, "rope_scaling"): # for LLaMA models
+            require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
+
+            if is_trainable:
+                if model_args.rope_scaling == "dynamic":
+                    logger.warning(
+                        "Dynamic NTK may not work well with fine-tuning. "
+                        "See: https://github.com/huggingface/transformers/pull/24653"
+                    )
+
+                current_max_length = getattr(config, "max_position_embeddings", None)
+                if current_max_length and model_args.model_max_length > current_max_length:
+                    scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+                else:
+                    logger.warning("Input length is smaller than max length. Consider increase input length.")
+                    scaling_factor = 1.0
+            else:
+                scaling_factor = 2.0
+
+            setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+            logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
+                model_args.rope_scaling, scaling_factor
+            ))
+
+        else:
+            logger.warning("Current model does not support RoPE scaling.")

     # Quantization configurations (using bitsandbytes library).
+    is_mergeable = True
     if model_args.quantization_bit is not None:
+        if is_deepspeed_zero3_enabled():
+            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+
         if model_args.quantization_bit == 8:
             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
             config_kwargs["load_in_8bit"] = True
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_8bit=True,
-                llm_int8_threshold=6.0
-            )
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

         elif model_args.quantization_bit == 4:
             require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
@@ -93,28 +140,26 @@ def load_model_and_tokenizer(
             )

         is_mergeable = False
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

-    if (
-        model_args.quantization_bit is not None
-        or (os.environ.get('LOCAL_RANK') is not None and not is_deepspeed_zero3_enabled())
-    ):
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
-
-    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
-        model_to_load = model_args.checkpoint_dir[0]
-    else:
-        model_to_load = model_args.model_name_or_path
-
-    # Load and prepare pretrained models (without valuehead).
+    # Load and prepare pre-trained models (without valuehead).
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
-        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
+        torch_dtype=model_args.compute_dtype,
         low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )

+    # Disable custom generate method (for Qwen)
+    if "GenerationMixin" not in str(model.generate.__func__):
+        model.generate = MethodType(PreTrainedModel.generate, model)
+
+    # Fix LM head (for ChatGLM2)
+    if not hasattr(model, "lm_head") and hasattr(model, "transformer"):
+        setattr(model, "lm_head", model.transformer.output_layer)
+
     # Register auto class to save the custom code files.
     if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()
@@ -127,10 +172,10 @@ def load_model_and_tokenizer(
     model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)

-    if stage == "rm" or stage == "ppo": # add value head
-        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+    # Prepare model with valuehead for RLHF
+    if stage == "rm" or stage == "ppo":
+        model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         reset_logging()

         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
             logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
             if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
@@ -140,15 +185,15 @@ def load_model_and_tokenizer(
                 })

         if stage == "ppo": # load reward model
-            assert is_trainable, "PPO stage cannot be performed at evaluation."
-            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
             logger.info("Load reward model from {}".format(model_args.reward_model))
             model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
             assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

+    # Prepare model for inference
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
+        infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
+        model = model.to(infer_dtype) if model_args.quantization_bit is None else model

     trainable_params, all_param = count_parameters(model)
     logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(

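Illustrative numbers for the RoPE scaling factor computed in the branch above; the lengths are placeholders, not values from this commit.

import math

max_position_embeddings = 4096  # placeholder, e.g. a LLaMA-2 config value
model_max_length = 8192         # placeholder requested training length

scaling_factor = float(math.ceil(model_max_length / max_position_embeddings))  # -> 2.0
rope_scaling = {"type": "linear", "factor": scaling_factor}
# At inference time (is_trainable is False) the factor is fixed to 2.0 regardless of lengths.
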
@@ -5,6 +5,7 @@ import datasets
|
|||||||
import transformers
|
import transformers
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||||
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
|
|
||||||
from llmtuner.extras.logging import get_logger
|
from llmtuner.extras.logging import get_logger
|
||||||
from llmtuner.hparams import (
|
from llmtuner.hparams import (
|
||||||
@@ -19,7 +20,7 @@ from llmtuner.hparams import (
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
|
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||||
if args is not None:
|
if args is not None:
|
||||||
return parser.parse_dict(args)
|
return parser.parse_dict(args)
|
||||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||||
@@ -32,26 +33,53 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
|
|||||||
|
|
||||||
def parse_train_args(
|
def parse_train_args(
|
||||||
args: Optional[Dict[str, Any]] = None
|
args: Optional[Dict[str, Any]] = None
|
||||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
|
) -> Tuple[
|
||||||
|
ModelArguments,
|
||||||
|
DataArguments,
|
||||||
|
Seq2SeqTrainingArguments,
|
||||||
|
FinetuningArguments,
|
||||||
|
GeneratingArguments,
|
||||||
|
GeneralArguments
|
||||||
|
]:
|
||||||
parser = HfArgumentParser((
|
parser = HfArgumentParser((
|
||||||
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
|
ModelArguments,
|
||||||
|
DataArguments,
|
||||||
|
Seq2SeqTrainingArguments,
|
||||||
|
FinetuningArguments,
|
||||||
|
GeneratingArguments,
|
||||||
|
GeneralArguments
|
||||||
))
|
))
|
||||||
return _parse_args(parser, args)
|
return _parse_args(parser, args)
|
||||||
|
|
||||||
|
|
||||||
def parse_infer_args(
|
def parse_infer_args(
|
||||||
args: Optional[Dict[str, Any]] = None
|
args: Optional[Dict[str, Any]] = None
|
||||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
|
) -> Tuple[
|
||||||
|
ModelArguments,
|
||||||
|
DataArguments,
|
||||||
|
FinetuningArguments,
|
||||||
|
GeneratingArguments
|
||||||
|
]:
|
||||||
parser = HfArgumentParser((
|
parser = HfArgumentParser((
|
||||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
ModelArguments,
|
||||||
|
DataArguments,
|
||||||
|
FinetuningArguments,
|
||||||
|
GeneratingArguments
|
||||||
))
|
))
|
||||||
return _parse_args(parser, args)
|
return _parse_args(parser, args)
|
||||||
|
|
||||||
|
|
||||||
def get_train_args(
|
def get_train_args(
|
||||||
args: Optional[Dict[str, Any]] = None
|
args: Optional[Dict[str, Any]] = None
|
||||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
|
) -> Tuple[
|
||||||
model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
|
ModelArguments,
|
||||||
|
DataArguments,
|
||||||
|
Seq2SeqTrainingArguments,
|
||||||
|
FinetuningArguments,
|
||||||
|
GeneratingArguments,
|
||||||
|
GeneralArguments
|
||||||
|
]:
|
||||||
|
model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
if training_args.should_log:
|
if training_args.should_log:
|
||||||
@@ -67,40 +95,61 @@ def get_train_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     data_args.init_for_training()

-    assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
+    if general_args.stage != "sft" and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set as True except SFT.")

-    assert not (training_args.do_train and training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True while training."
+    if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
+        raise ValueError("Please enable `predict_with_generate` to save model predictions.")

-    assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
-        "Please enable `predict_with_generate` to save model predictions."
+    if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
+        raise ValueError("RM and PPO stages can only be performed with the LoRA method.")

-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
+    if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
+        raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")

-    assert not (training_args.max_steps == -1 and data_args.streaming), \
-        "Please specify `max_steps` in streaming mode."
+    if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
+        raise ValueError("PPO and DPO stages can only be performed at training.")

-    assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
-        "Streaming mode does not support evaluation currently."
+    if general_args.stage == "ppo" and model_args.reward_model is None:
+        raise ValueError("Reward model is necessary for PPO training.")

-    assert not (general_args.stage == "ppo" and data_args.streaming), \
-        "Streaming mode does not suppport PPO training currently."
+    if general_args.stage == "ppo" and data_args.streaming:
+        raise ValueError("Streaming mode does not suppport PPO training currently.")
+
+    if training_args.max_steps == -1 and data_args.streaming:
+        raise ValueError("Please specify `max_steps` in streaming mode.")
+
+    if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
+        raise ValueError("Streaming mode should have an integer val size.")
+
+    if training_args.do_train and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set as True while training.")
+
+    if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
+        raise ValueError("Please specify `lora_target` in LoRA training.")
+
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")

     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint.")

     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")

-    if training_args.do_train and (not training_args.fp16):
-        logger.warning("We recommend enable fp16 mixed precision training.")
+    if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
+        logger.warning("We recommend enable mixed precision training.")
+
+    # postprocess data_args
+    if data_args.max_samples is not None and data_args.streaming:
+        logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
+        data_args.max_samples = None

+    # postprocess training_args
     if (
         training_args.local_rank != -1
         and training_args.ddp_find_unused_parameters is None
@@ -109,50 +158,66 @@ def get_train_args(
         logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
         training_args.ddp_find_unused_parameters = False

-    if data_args.max_samples is not None and data_args.streaming:
-        logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
-        data_args.max_samples = None
-
-    if data_args.dev_ratio > 1e-6 and data_args.streaming:
-        logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
-        data_args.dev_ratio = 0
-
-    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
-
-    if model_args.quantization_bit is not None:
-        if training_args.fp16:
-            model_args.compute_dtype = torch.float16
-        elif training_args.bf16:
-            model_args.compute_dtype = torch.bfloat16
-        else:
-            model_args.compute_dtype = torch.float32
+    if training_args.optim == "adamw_hf":
+        training_args.optim = "adamw_torch" # suppress warning
+
+    if (
+        training_args.resume_from_checkpoint is None
+        and training_args.do_train
+        and os.path.isdir(training_args.output_dir)
+        and not training_args.overwrite_output_dir
+    ):
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
+
+        if last_checkpoint is not None:
+            training_args.resume_from_checkpoint = last_checkpoint
+            logger.info(
+                "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
+            )
+
+    # postprocess model_args
+    if training_args.bf16:
+        if not torch.cuda.is_bf16_supported():
+            raise ValueError("Current device does not support bf16 training.")
+        model_args.compute_dtype = torch.bfloat16
+    else:
+        model_args.compute_dtype = torch.float16
+
+    model_args.model_max_length = data_args.max_source_length + data_args.max_target_length

     # Log on each process the small summary:
-    logger.info("Process rank: {}, device: {}, n_gpu: {}\n  distributed training: {}, 16-bits training: {}".format(
+    logger.info("Process rank: {}, device: {}, n_gpu: {}\n  distributed training: {}, compute dtype: {}".format(
         training_args.local_rank, training_args.device, training_args.n_gpu,
-        bool(training_args.local_rank != -1), training_args.fp16
+        bool(training_args.local_rank != -1), str(model_args.compute_dtype)
     ))
     logger.info(f"Training/evaluation parameters {training_args}")

     # Set seed before initializing model.
     transformers.set_seed(training_args.seed)

-    return model_args, data_args, training_args, finetuning_args, general_args
+    return model_args, data_args, training_args, finetuning_args, generating_args, general_args

 def get_infer_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    FinetuningArguments,
+    GeneratingArguments
+]:
     model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)

-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")

     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint.")

     return model_args, data_args, finetuning_args, generating_args
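Note (not part of this changeset): both parsers accept either command-line arguments or an explicit dict. A minimal sketch of programmatic use, where the value strings are made up and `model_name_or_path`/`dataset` are assumed field names, while the remaining keys appear in the checks above:

    from llmtuner.tuner.core import get_train_args

    model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args({
        "stage": "sft",                          # general_args.stage, validated above
        "model_name_or_path": "path/to/model",   # assumed ModelArguments field
        "dataset": "alpaca_en",                  # assumed DataArguments field
        "do_train": True,
        "finetuning_type": "lora",
        "lora_target": "q_proj,v_proj",          # required for LoRA training by the new check
        "output_dir": "path/to/output",
        "fp16": True,
    })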
@@ -13,26 +13,25 @@ from llmtuner.extras.logging import get_logger
 from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params

 if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
     from llmtuner.hparams import FinetuningArguments


 logger = get_logger(__name__)


-class PeftTrainer(Seq2SeqTrainer):
+class PeftModelMixin:
     r"""
-    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
+    Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
     """

-    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
-        super().__init__(**kwargs)
-        self.finetuning_args = finetuning_args
-        self._remove_log()
-
-    def _remove_log(self):
-        if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
-            logger.warning("Previous log file in this folder will be deleted.")
-            os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
+    def __init__(self) -> None: # for type checking
+        self.model: PreTrainedModel = None
+        self.tokenizer: "PreTrainedTokenizer" = None
+        self.args: "Seq2SeqTrainingArguments" = None
+        self.finetuning_args: "FinetuningArguments" = None
+        self.state: "TrainerState" = None
+        raise AssertionError("Mixin should not be initialized.")

     def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
         r"""

@@ -47,7 +46,6 @@ class PeftTrainer(Seq2SeqTrainer):
         logger.info(f"Saving model checkpoint to {output_dir}")

         model = unwrap_model(self.model)
-
         if isinstance(model, PreTrainedModelWrapper):
             # Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
             model_state_dict = state_dict or model.state_dict()

@@ -68,7 +66,10 @@ class PeftTrainer(Seq2SeqTrainer):
             torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

         if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
-            self.tokenizer.save_pretrained(output_dir)
+            try:
+                self.tokenizer.save_pretrained(output_dir)
+            except:
+                logger.warning("Cannot save tokenizer, copy the files manually.")

         with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
             f.write(self.args.to_json_string() + "\n")

@@ -94,3 +95,13 @@ class PeftTrainer(Seq2SeqTrainer):
             model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
         else: # freeze/full-tuning
             load_trainable_params(model, self.state.best_model_checkpoint)
+
+
+class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
+    r"""
+    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
+    """
+
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
+        Seq2SeqTrainer.__init__(self, **kwargs)
+        self.finetuning_args = finetuning_args
1  src/llmtuner/tuner/dpo/__init__.py  Normal file
@@ -0,0 +1 @@
from llmtuner.tuner.dpo.workflow import run_dpo
51  src/llmtuner/tuner/dpo/collator.py  Normal file
@@ -0,0 +1,51 @@
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple
from transformers import DataCollatorForSeq2Seq


@dataclass
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """

    def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
        padded_labels = []
        for feature, (prompt_len, answer_len) in zip(batch, positions):
            if self.tokenizer.padding_side == "left":
                start, end = feature.size(0) - answer_len, feature.size(0)
            else:
                start, end = prompt_len, answer_len
            padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
            padded_tensor[start:end] = feature[start:end]
            padded_labels.append(padded_tensor)
        return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        concatenated_features = []
        label_positions = []
        for key in ("chosen_ids", "rejected_ids"):
            for feature in features:
                prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
                concatenated_features.append({
                    "input_ids": feature["prompt_ids"] + feature[key],
                    "attention_mask": [1] * (prompt_len + answer_len)
                })
                label_positions.append((prompt_len, answer_len))

        batch = self.tokenizer.pad(
            concatenated_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
        return batch
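Note (illustrative only, not part of this changeset): a toy invocation of the new collator. The tokenizer choice and the integer ids are made up, but `prompt_ids`, `chosen_ids` and `rejected_ids` are the keys the collator actually reads:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")   # any tokenizer with a pad token works
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    collator = DPODataCollatorWithPadding(tokenizer=tokenizer, label_pad_token_id=-100)
    features = [
        {"prompt_ids": [1, 2, 3], "chosen_ids": [4, 5], "rejected_ids": [6]},
        {"prompt_ids": [7, 8], "chosen_ids": [9, 10, 11], "rejected_ids": [12, 13]},
    ]
    batch = collator(features)
    # 2 * n = 4 rows: the first n are prompt+chosen, the last n are prompt+rejected;
    # with left padding, labels keep only the answer tokens and fill the rest with -100.
    print(batch["input_ids"].shape, batch["labels"].shape)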
77  src/llmtuner/tuner/dpo/trainer.py  Normal file
@@ -0,0 +1,77 @@
import torch
from collections import defaultdict
from peft import PeftModel
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from transformers import BatchEncoding, Trainer
from trl import DPOTrainer

from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.tuner.core.trainer import PeftModelMixin

if TYPE_CHECKING:
    from transformers import PreTrainedModel
    from llmtuner.hparams import FinetuningArguments, GeneratingArguments


class DPOPeftTrainer(PeftModelMixin, DPOTrainer):

    def __init__(
        self,
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
        **kwargs
    ):
        self.finetuning_args = finetuning_args
        self.generating_args = generating_args
        self.ref_model = ref_model
        self.use_dpo_data_collator = True # hack to avoid warning
        self.label_pad_token_id = IGNORE_INDEX
        self.padding_value = 0
        self.beta = finetuning_args.dpo_beta
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        Trainer.__init__(self, **kwargs)
        if not hasattr(self, "accelerator"):
            raise AttributeError("Please update `transformers`.")

        if ref_model is not None:
            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

    def concatenated_forward(
        self,
        model: Optional[torch.nn.Module] = None,
        batch: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)

        if not torch.is_grad_enabled():
            unwrapped_model.gradient_checkpointing_disable()

        if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
            with unwrapped_model.disable_adapter():
                all_logits = self.model(
                    input_ids=batch_copied["input_ids"],
                    attention_mask=batch_copied["attention_mask"],
                    return_dict=True
                ).logits.to(torch.float32)
        else:
            all_logits = model(
                input_ids=batch_copied["input_ids"],
                attention_mask=batch_copied["attention_mask"],
                return_dict=True
            ).logits.to(torch.float32)

        if not torch.is_grad_enabled():
            unwrapped_model.gradient_checkpointing_enable()

        all_logps = self._get_batch_logps(
            all_logits,
            batch["labels"],
            average_log_prob=False
        )
        batch_size = batch["input_ids"].size(0) // 2
        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
        return chosen_logps, rejected_logps, chosen_logits, rejected_logits
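Note (illustrative only, not part of this changeset): the preference loss itself comes from trl's DPOTrainer, which consumes the log-probabilities produced by `concatenated_forward` above together with those of the reference model. For reference, a standalone sketch of the sigmoid-variant DPO objective:

    import torch
    import torch.nn.functional as F

    def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                        ref_chosen_logps, ref_rejected_logps, beta=0.1):
        # implicit rewards are beta-scaled log-ratios of policy vs. reference
        chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
        rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
        # maximize the margin between chosen and rejected responses
        return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()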
59  src/llmtuner/tuner/dpo/workflow.py  Normal file
@@ -0,0 +1,59 @@
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py

from copy import deepcopy
from peft import PeftModel
from typing import TYPE_CHECKING, Optional, List

from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
from llmtuner.tuner.dpo.trainer import DPOPeftTrainer

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback
    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments


def run_dpo(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[List["TrainerCallback"]] = None
):
    dataset = get_dataset(model_args, data_args)
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
    data_collator = DPODataCollatorWithPadding(
        tokenizer=tokenizer,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )

    training_args.remove_unused_columns = False # important for pairwise dataset
    ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None

    # Initialize our Trainer
    trainer = DPOPeftTrainer(
        finetuning_args=finetuning_args,
        generating_args=generating_args,
        ref_model=ref_model,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **split_dataset(dataset, data_args, training_args)
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
@@ -2,24 +2,23 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

 from transformers import TrainerState, TrainerControl
-from transformers.modeling_utils import PreTrainedModel

 from trl import PPOTrainer
-from trl.core import LengthSampler
+from trl.core import LengthSampler, PPODecorators, logprobs_from_logits

 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor

 from llmtuner.tuner.core.trainer import PeftTrainer
-from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
+from llmtuner.tuner.ppo.utils import replace_model

 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments
+    from trl import AutoModelForCausalLMWithValueHead
     from llmtuner.extras.callbacks import LogCallback
-    from llmtuner.hparams import FinetuningArguments
+    from llmtuner.hparams import FinetuningArguments, GeneratingArguments


 logger = get_logger(__name__)

@@ -34,17 +33,19 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         self,
         training_args: "Seq2SeqTrainingArguments",
         finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
         callbacks: List["LogCallback"],
+        compute_dtype: torch.dtype,
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
         self.args = training_args
         self.finetuning_args = finetuning_args
+        self.generating_args = generating_args
         self.log_callback = callbacks[0]
+        self.compute_dtype = compute_dtype
         self.state = TrainerState()
         self.control = TrainerControl()
-        self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
-        self._remove_log()

     def ppo_train(self, max_target_length: int) -> None:
         r"""

@@ -74,16 +75,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         logger.info(f"  Number of trainable parameters = {count_parameters(self.model)[0]}")

         # Keyword arguments for `model.generate`
-        gen_kwargs = {
-            "top_k": 0.0,
-            "top_p": 1.0,
-            "do_sample": True,
-            "pad_token_id": self.tokenizer.pad_token_id,
-            "eos_token_id": self.tokenizer.eos_token_id,
-            "logits_processor": get_logits_processor()
-        }
+        gen_kwargs = self.generating_args.to_dict()
+        gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids))
+        gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
+        gen_kwargs["logits_processor"] = get_logits_processor()

         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
-        unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)

         dataiter = iter(self.dataloader)
         steps_trained = 0

@@ -91,51 +89,38 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         reward_meter = AverageMeter()
         self.log_callback.on_train_begin(self.args, self.state, self.control)

-        for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
+        for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
             batch = next(dataiter)
             steps_trained += 1

+            # Cast to inference mode
             unwrapped_model.gradient_checkpointing_disable()
             unwrapped_model.config.use_cache = True

-            # Get responses
-            query_tensors = batch["input_ids"]
-            response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
-
-            queries, responses = [], []
-            for i in range(len(query_tensors)):
-                query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
-                response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-                queries.append(query_tensors[i, query_length:]) # remove padding from left
-                responses.append(response_tensors[i, :response_length]) # remove padding from right
-
-            # Compute rewards
-            replace_model(unwrapped_model, target="reward")
-            with torch.no_grad():
-                _, _, values = self.model(
-                    **self.prepare_model_inputs(queries, responses),
-                    output_hidden_states=True,
-                    return_dict=True
-                )
-            rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
-            replace_model(unwrapped_model, target="default")
-
-            # Run PPO step
+            # Get inputs
+            queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
+            rewards = self.get_rewards(queries, responses, unwrapped_model)
+
+            # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
             unwrapped_model.config.use_cache = False
-            stats = self.step(queries, responses, rewards)

+            # Run PPO step
+            stats = self.step(queries, responses, rewards)
             loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))

-            if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
+            self.state.global_step += 1
+            self.log_callback.on_step_end(self.args, self.state, self.control)
+
+            if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
                 logs = dict(
                     loss=round(loss_meter.avg, 4),
                     reward=round(reward_meter.avg, 4),
                     learning_rate=stats["ppo/learning_rate"],
                     epoch=round(step / len_dataloader, 2)
                 )
-                print(logs)
+                tqdm.write(str(logs))
                 logs["step"] = step
                 self.state.log_history.append(logs)
                 self.log_callback.on_log(self.args, self.state, self.control)
@@ -152,38 +137,124 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                 dataiter = iter(self.dataloader)
                 steps_trained = 0

+        self.log_callback.on_train_end(self.args, self.state, self.control)
+
     @torch.no_grad()
-    def generate(
+    def get_inputs(
         self,
-        inputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
         length_sampler: Optional[Callable] = None,
-        return_prompt: Optional[bool] = True,
         **generation_kwargs
-    ) -> torch.Tensor:
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
-
-        Subclass and override to inject custom behavior.
         """
-        self.model, layer_norm_params = cast_layernorm_dtype(self.model)
-
         if length_sampler is not None:
             generation_kwargs["max_new_tokens"] = length_sampler()

-        unwrapped_model = self.accelerator.unwrap_model(self.model)
-        response = unwrapped_model.generate(**inputs, **generation_kwargs)
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+        response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)

         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
         # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
         if unwrapped_model.pretrained_model.generation_config._from_model_config:
             unwrapped_model.pretrained_model.generation_config._from_model_config = False

-        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
-
-        if not return_prompt and not self.is_encoder_decoder:
-            return response[:, inputs["input_ids"].size(1):]
-        return response
+        queries, responses = [], []
+        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
+        for i in range(len(query)):
+            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
+            response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            queries.append(query[i, query_length:]) # remove padding from left
+            responses.append(response[i, :response_length]) # remove padding from right
+
+        return queries, responses
+
+    @torch.no_grad()
+    def get_rewards(
+        self,
+        queries: List[torch.Tensor],
+        responses: List[torch.Tensor],
+        unwrapped_model: "AutoModelForCausalLMWithValueHead"
+    ) -> List[torch.Tensor]:
+        r"""
+        Computes scores using given reward model.
+        """
+        replace_model(unwrapped_model, target="reward")
+        batch = self.prepare_model_inputs(queries, responses)
+
+        with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
+            _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
+
+        if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
+            values = torch.transpose(values, 0, 1)
+
+        rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
+        replace_model(unwrapped_model, target="default")
+        return rewards
+
+    @PPODecorators.empty_cuda_cache()
+    def batched_forward_pass(
+        self,
+        model: "AutoModelForCausalLMWithValueHead",
+        queries: torch.Tensor,
+        responses: torch.Tensor,
+        model_inputs: dict,
+        return_logits: Optional[bool] = False
+    ):
+        r"""
+        Calculates model outputs in multiple batches.
+
+        Subclass and override to inject custom behavior.
+        """
+        bs = len(queries)
+        fbs = self.config.mini_batch_size
+        all_logprobs = []
+        all_logits = []
+        all_masks = []
+        all_values = []
+
+        for i in range(math.ceil(bs / fbs)):
+            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
+            query_batch = queries[i * fbs : (i + 1) * fbs]
+            response_batch = responses[i * fbs : (i + 1) * fbs]
+            input_ids = input_kwargs["input_ids"]
+            attention_mask = input_kwargs["attention_mask"]
+
+            with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
+                logits, _, values = model(**input_kwargs)
+
+            if values.size(0) != input_ids.size(0): # adapt to chatglm2
+                values = torch.transpose(values, 0, 1)
+
+            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
+            masks = torch.zeros_like(attention_mask)
+            masks[:, :-1] = attention_mask[:, 1:]
+
+            for j in range(len(query_batch)):
+                start = len(query_batch[j]) - 1
+                if attention_mask[j, 0] == 0: # offset left padding
+                    start += attention_mask[j, :].nonzero()[0]
+                end = start + len(response_batch[j])
+
+                masks[j, :start] = 0
+                masks[j, end:] = 0
+
+            if return_logits:
+                all_logits.append(logits)
+            else:
+                del logits
+
+            all_values.append(values)
+            all_logprobs.append(logprobs)
+            all_masks.append(masks)
+
+        return (
+            torch.cat(all_logprobs),
+            torch.cat(all_logits)[:, :-1] if return_logits else None,
+            torch.cat(all_values)[:, :-1],
+            torch.cat(all_masks)[:, :-1],
+        )

     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""
@@ -1,7 +1,4 @@
-import torch
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
-
-from llmtuner.extras.constants import LAYERNORM_NAMES
+from typing import TYPE_CHECKING, Literal

 if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead

@@ -18,22 +15,3 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
         "summary.weight": getattr(model, "{}_head_weight".format(target)),
         "summary.bias": getattr(model, "{}_head_bias".format(target))
     })
-
-
-def cast_layernorm_dtype(
-    model: "AutoModelForCausalLMWithValueHead",
-    layer_norm_names: List[str] = LAYERNORM_NAMES,
-    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
-) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
-
-    layer_norm_state_dict = {}
-
-    for name, param in model.named_parameters():
-        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            if layer_norm_params is not None:
-                param.data = layer_norm_params[name] # restore float32 weights
-            else:
-                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
-                param.data = param.data.to(torch.float16)
-
-    return model, layer_norm_state_dict
@@ -1,23 +1,21 @@
-# Inspired by:
-# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
+# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py

 import math
-from typing import TYPE_CHECKING
 from trl import PPOConfig
 from torch.optim import AdamW
-from typing import Optional, List
+from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq
 from transformers.optimization import get_scheduler
+from transformers.utils.versions import require_version

 from llmtuner.dsets import get_dataset, preprocess_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.ppo.trainer import PPOPeftTrainer

 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments


 def run_ppo(

@@ -25,7 +23,8 @@ def run_ppo(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    generating_args: "GeneratingArguments",
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")

@@ -39,24 +38,35 @@ def run_ppo(
         batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
         ppo_epochs=1,
-        max_grad_norm=training_args.max_grad_norm
+        max_grad_norm=training_args.max_grad_norm,
+        seed=training_args.seed,
+        optimize_cuda_cache=True
     )

-    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
-    total_train_batch_size = \
+    if finetuning_args.ppo_score_norm:
+        require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
+        ppo_config.use_score_scaling = True
+        ppo_config.use_score_norm = True
+
+    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
+    total_train_batch_size = (
         training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+    )
+    num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
     lr_scheduler = get_scheduler(
         training_args.lr_scheduler_type,
         optimizer=optimizer,
-        num_warmup_steps=training_args.warmup_steps,
-        num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
+        num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+        num_training_steps=num_training_steps
     )

     # Initialize our Trainer
     ppo_trainer = PPOPeftTrainer(
         training_args=training_args,
         finetuning_args=finetuning_args,
+        generating_args=generating_args,
         callbacks=callbacks,
+        compute_dtype=model_args.compute_dtype,
         config=ppo_config,
         model=model,
         ref_model=None,

@@ -67,8 +77,10 @@ def run_ppo(
         lr_scheduler=lr_scheduler
     )

-    ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
-    ppo_trainer.save_model()
-    ppo_trainer.save_state() # must be after save_model
-    if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args.output_dir, keys=["loss", "reward"])
+    # Training
+    if training_args.do_train:
+        ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
+        ppo_trainer.save_model()
+        ppo_trainer.save_state() # must be called after save_model to have a folder
+        if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
+            plot_loss(training_args.output_dir, keys=["loss", "reward"])
@@ -2,11 +2,9 @@

 import math
 from typing import TYPE_CHECKING, Optional, List
-from transformers import DataCollatorForSeq2Seq
+from transformers import DataCollatorForLanguageModeling

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.core.trainer import PeftTrainer

@@ -21,15 +19,12 @@ def run_pt(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
-    data_collator = DataCollatorForSeq2Seq(
-        tokenizer=tokenizer,
-        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
-    )
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

     # Initialize our Trainer
     trainer = PeftTrainer(

@@ -39,12 +34,12 @@ def run_pt(
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )

     # Training
     if training_args.do_train:
-        train_result = trainer.train()
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()

@@ -61,6 +56,5 @@ def run_pt(
         perplexity = float("inf")

     metrics["perplexity"] = perplexity
-
     trainer.log_metrics("eval", metrics)
     trainer.save_metrics("eval", metrics)
@@ -1,8 +1,10 @@
 import torch
+from dataclasses import dataclass
 from typing import Any, Dict, Sequence
 from transformers import DataCollatorWithPadding


+@dataclass
 class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
     r"""
     Data collator for pairwise data.

@@ -16,7 +18,10 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
         the last n examples represent rejected examples.
         """
         features = [
-            {"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
-            for key in ("accept_ids", "reject_ids") for feature in features
+            {
+                "input_ids": feature["prompt_ids"] + feature[key],
+                "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
+            }
+            for key in ("chosen_ids", "rejected_ids") for feature in features
         ]
         return super().__call__(features)
@@ -42,6 +42,8 @@ class PairwisePeftTrainer(PeftTrainer):
         """
         batch_size = inputs["input_ids"].size(0) // 2
         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
+        if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
+            values = torch.transpose(values, 0, 1)
         r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
         loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
         return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
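Note (illustrative only, not part of this changeset): the two added lines transpose the value tensor for models (the comment cites chatglm2) whose value head puts the batch dimension second, so the pairwise loss right below keeps working. That loss is the standard Bradley–Terry log-sigmoid objective; a tiny self-contained check with made-up scores:

    import torch

    r_accept = torch.tensor([2.0, 0.5])   # scores of chosen responses (illustrative)
    r_reject = torch.tensor([1.0, 1.5])   # scores of rejected responses
    loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
    # numerically safer equivalent:
    loss_stable = -torch.nn.functional.logsigmoid(r_accept - r_reject).mean()
    print(loss.item(), loss_stable.item())  # both ≈ 0.8133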
@@ -5,7 +5,6 @@
 from typing import TYPE_CHECKING, Optional, List

 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.rm.metric import compute_accuracy

@@ -22,7 +21,7 @@ def run_rm(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")

@@ -40,7 +39,7 @@ def run_rm(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=compute_accuracy,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )

     # Training
@@ -25,7 +25,7 @@ class ComputeMetrics:
         Uses the model predictions to compute metrics.
         """
         preds, labels = eval_preds
-        score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
+        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}

         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
         labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)

@@ -49,6 +49,5 @@ class ComputeMetrics:

             bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
             score_dict["bleu-4"].append(round(bleu_score * 100, 4))
-            score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))

         return {k: float(np.mean(v)) for k, v in score_dict.items()}
@@ -50,11 +50,12 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        generated_tokens = (
-            generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
-        )
+        if generated_tokens is not None:
+            generated_tokens[:, :max(prompt_len, label_len)] = (
+                self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
+            )

-        return (loss, generated_tokens, labels)
+        return loss, generated_tokens, labels

     def _pad_tensors_to_target_len(
         self,

@@ -72,14 +73,11 @@ class Seq2SeqPeftTrainer(PeftTrainer):
             assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
             pad_token_id = self.tokenizer.pad_token_id
         else:
-            if self.model.config.pad_token_id is not None:
-                pad_token_id = self.model.config.pad_token_id
-            else:
-                raise ValueError("Pad_token_id must be set in the configuration of the model.")
+            raise ValueError("PAD token is required.")

         padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
         padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
-        return padded_tensor
+        return padded_tensor.contiguous() # in contiguous memory

     def save_predictions(
         self,
@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
@@ -14,7 +13,7 @@ from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 
 
 def run_sft(
@@ -22,7 +21,8 @@ def run_sft(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    generating_args: "GeneratingArguments",
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
@@ -47,21 +47,18 @@ def run_sft(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )
 
     # Keyword arguments for `model.generate`
-    gen_kwargs = {
-        "do_sample": True,
-        "top_p": 0.7,
-        "max_new_tokens": data_args.max_target_length + 1,
-        "temperature": 0.95,
-        "logits_processor": get_logits_processor()
-    }
+    gen_kwargs = generating_args.to_dict()
+    gen_kwargs["eos_token_id"] = list(set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids))
+    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+    gen_kwargs["logits_processor"] = get_logits_processor()
 
     # Training
     if training_args.do_train:
-        train_result = trainer.train()
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
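Note on the run_sft hunk above: generation settings now come from the GeneratingArguments dataclass rather than hard-coded sampling values. As a rough, illustrative sketch (not part of this diff; the trainer and dataset names are assumed), the assembled gen_kwargs would typically be forwarded to the generate-based prediction step:

    # Illustrative only: consuming the gen_kwargs assembled from generating_args.to_dict()
    if training_args.do_predict:
        predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
        trainer.log_metrics("predict", predict_results.metrics)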
src/llmtuner/tuner/tune.py (new file, 48 lines)
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from llmtuner.extras.callbacks import LogCallback
+from llmtuner.extras.logging import get_logger
+from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
+from llmtuner.tuner.pt import run_pt
+from llmtuner.tuner.sft import run_sft
+from llmtuner.tuner.rm import run_rm
+from llmtuner.tuner.ppo import run_ppo
+from llmtuner.tuner.dpo import run_dpo
+
+if TYPE_CHECKING:
+    from transformers import TrainerCallback
+
+
+logger = get_logger(__name__)
+
+
+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
+    model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
+    callbacks = [LogCallback()] if callbacks is None else callbacks
+
+    if general_args.stage == "pt":
+        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif general_args.stage == "sft":
+        run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+    elif general_args.stage == "rm":
+        run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif general_args.stage == "ppo":
+        run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+    elif general_args.stage == "dpo":
+        run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
+    else:
+        raise ValueError("Unknown task.")
+
+
+def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
+    model_args, _, training_args, finetuning_args, _, _ = get_train_args(args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
+    try:
+        tokenizer.save_pretrained(training_args.output_dir)
+    except:
+        logger.warning("Cannot save tokenizer, please copy the files manually.")
+
+
+if __name__ == "__main__":
+    run_exp()
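The new run_exp entry point accepts either command-line arguments (parsed by get_train_args) or a plain dict. A minimal, hypothetical example of programmatic use; the argument values below are placeholders, not taken from this diff:

    from llmtuner.tuner.tune import run_exp

    run_exp(args={
        "stage": "sft",
        "model_name_or_path": "path/to/base-model",  # placeholder
        "do_train": True,
        "dataset": "my_dataset",                     # placeholder dataset name
        "finetuning_type": "lora",
        "output_dir": "path/to/output"
    })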
@@ -0,0 +1 @@
+from llmtuner.webui.interface import create_ui, create_web_demo
@@ -1,22 +1,22 @@
 import os
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.extras.misc import torch_gc
 from llmtuner.hparams import GeneratingArguments
-from llmtuner.tuner import get_infer_args
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS
 
 
 class WebChatModel(ChatModel):
 
-    def __init__(self, *args):
-        self.model = None
-        self.tokenizer = None
-        self.generating_args = GeneratingArguments()
-        if len(args) != 0:
-            super().__init__(*args)
+    def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None:
+        if lazy_init:
+            self.model = None
+            self.tokenizer = None
+            self.generating_args = GeneratingArguments()
+        else:
+            super().__init__(args)
 
     def load_model(
         self,
@@ -26,7 +26,7 @@ class WebChatModel(ChatModel):
         finetuning_type: str,
         quantization_bit: str,
         template: str,
-        source_prefix: str
+        system_prompt: str
     ):
         if self.model is not None:
             yield ALERTS["err_exists"][lang]
@@ -53,11 +53,11 @@ class WebChatModel(ChatModel):
             model_name_or_path=model_name_or_path,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
-            source_prefix=source_prefix
+            system_prompt=system_prompt
         )
-        super().__init__(*get_infer_args(args))
+        super().__init__(args)
 
         yield ALERTS["info_loaded"][lang]
 
@@ -73,7 +73,7 @@ class WebChatModel(ChatModel):
         chatbot: List[Tuple[str, str]],
         query: str,
         history: List[Tuple[str, str]],
-        prefix: str,
+        system: str,
         max_new_tokens: int,
         top_p: float,
         temperature: float
@@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
         chatbot.append([query, ""])
         response = ""
         for new_text in self.stream_chat(
-            query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
             response = self.postprocess(response)
@@ -6,7 +6,7 @@ import gradio as gr
 from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
 from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
 
-from llmtuner.extras.constants import SUPPORTED_MODELS
+from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS
 
 
 DEFAULT_CACHE_DIR = "cache"
@@ -29,14 +29,16 @@ def load_config() -> Dict[str, Any]:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return json.load(f)
     except:
-        return {"last_model": "", "path_dict": {}}
+        return {"lang": "", "last_model": "", "path_dict": {}}
 
 
-def save_config(model_name: str, model_path: str) -> None:
+def save_config(lang: str, model_name: str, model_path: str) -> None:
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
-    user_config["last_model"] = model_name
-    user_config["path_dict"][model_name] = model_path
+    user_config["lang"] = lang or user_config["lang"]
+    if model_name:
+        user_config["last_model"] = model_name
+        user_config["path_dict"][model_name] = model_path
     with open(get_config_path(), "w", encoding="utf-8") as f:
         json.dump(user_config, f, indent=2, ensure_ascii=False)
 
@@ -46,6 +48,12 @@ def get_model_path(model_name: str) -> str:
     return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
 
 
+def get_template(model_name: str) -> str:
+    if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
+        return DEFAULT_TEMPLATE[model_name.split("-")[0]]
+    return "default"
+
+
 def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
     checkpoints = []
     save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
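For orientation, the user config maintained by load_config and save_config above is a small JSON cache that now also stores the language preference. A sketch of the dict shape after this change, with placeholder values:

    # Illustrative shape of the dict returned by load_config()
    {"lang": "en", "last_model": "Example-7B", "path_dict": {"Example-7B": "/path/to/example-7b"}}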
@@ -1,5 +1,6 @@
 from llmtuner.webui.components.top import create_top
-from llmtuner.webui.components.sft import create_sft_tab
+from llmtuner.webui.components.train import create_train_tab
 from llmtuner.webui.components.eval import create_eval_tab
 from llmtuner.webui.components.infer import create_infer_tab
 from llmtuner.webui.components.export import create_export_tab
+from llmtuner.webui.components.chatbot import create_chat_box
@@ -17,7 +17,7 @@ def create_chat_box(
 
     with gr.Row():
         with gr.Column(scale=4):
-            prefix = gr.Textbox(show_label=False)
+            system = gr.Textbox(show_label=False)
             query = gr.Textbox(show_label=False, lines=8)
             submit_btn = gr.Button(variant="primary")
 
@@ -31,7 +31,7 @@ def create_chat_box(
 
     submit_btn.click(
         chat_model.predict,
-        [chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
+        [chatbot, query, history, system, max_new_tokens, top_p, temperature],
         [chatbot, history],
         show_progress=True
     ).then(
@@ -41,7 +41,7 @@ def create_chat_box(
     clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
 
     return chat_box, chatbot, history, dict(
-        prefix=prefix,
+        system=system,
         query=query,
         submit_btn=submit_btn,
         clear_btn=clear_btn,
@@ -16,6 +16,6 @@ def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"
 
         close_btn = gr.Button()
 
-        close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])
+        close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
 
     return preview_box, preview_count, preview_samples, close_btn
@@ -14,13 +14,18 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
     with gr.Row():
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
-        preview_btn = gr.Button(interactive=False, scale=1)
+        data_preview_btn = gr.Button(interactive=False, scale=1)
 
     preview_box, preview_count, preview_samples, close_btn = create_preview_box()
 
     dataset_dir.change(list_dataset, [dataset_dir], [dataset])
-    dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
-    preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
+    dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
+    data_preview_btn.click(
+        get_preview,
+        [dataset_dir, dataset],
+        [preview_count, preview_samples, preview_box],
+        queue=False
+    )
 
     with gr.Row():
         max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
@@ -30,38 +35,46 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
         predict = gr.Checkbox(value=True)
 
     with gr.Row():
+        cmd_preview_btn = gr.Button()
         start_btn = gr.Button()
         stop_btn = gr.Button()
 
+    with gr.Row():
+        process_bar = gr.Slider(visible=False, interactive=False)
+
     with gr.Box():
         output_box = gr.Markdown()
 
-    start_btn.click(
-        runner.run_eval,
-        [
-            top_elems["lang"],
-            top_elems["model_name"],
-            top_elems["checkpoints"],
-            top_elems["finetuning_type"],
-            top_elems["quantization_bit"],
-            top_elems["template"],
-            top_elems["source_prefix"],
-            dataset_dir,
-            dataset,
-            max_source_length,
-            max_target_length,
-            max_samples,
-            batch_size,
-            predict
-        ],
-        [output_box]
-    )
+    input_components = [
+        top_elems["lang"],
+        top_elems["model_name"],
+        top_elems["checkpoints"],
+        top_elems["finetuning_type"],
+        top_elems["quantization_bit"],
+        top_elems["template"],
+        top_elems["system_prompt"],
+        dataset_dir,
+        dataset,
+        max_source_length,
+        max_target_length,
+        max_samples,
+        batch_size,
+        predict
+    ]
+
+    output_components = [
+        output_box,
+        process_bar
+    ]
+
+    cmd_preview_btn.click(runner.preview_eval, input_components, output_components)
+    start_btn.click(runner.run_eval, input_components, output_components)
     stop_btn.click(runner.set_abort, queue=False)
 
     return dict(
         dataset_dir=dataset_dir,
         dataset=dataset,
-        preview_btn=preview_btn,
+        data_preview_btn=data_preview_btn,
         preview_count=preview_count,
         preview_samples=preview_samples,
         close_btn=close_btn,
@@ -70,6 +83,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
         max_samples=max_samples,
         batch_size=batch_size,
         predict=predict,
+        cmd_preview_btn=cmd_preview_btn,
         start_btn=start_btn,
         stop_btn=stop_btn,
         output_box=output_box
@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING, Dict
 import gradio as gr
 
-from llmtuner.webui.utils import export_model
+from llmtuner.webui.utils import save_model
 
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -16,12 +16,13 @@ def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component
     info_box = gr.Textbox(show_label=False, interactive=False)
 
     export_btn.click(
-        export_model,
+        save_model,
         [
             top_elems["lang"],
             top_elems["model_name"],
             top_elems["checkpoints"],
             top_elems["finetuning_type"],
+            top_elems["template"],
             max_shard_size,
             save_dir
         ],
@@ -28,7 +28,7 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
             top_elems["finetuning_type"],
             top_elems["quantization_bit"],
             top_elems["template"],
-            top_elems["source_prefix"]
+            top_elems["system_prompt"]
         ],
         [info_box]
     ).then(
@@ -4,7 +4,7 @@ import gradio as gr
 
 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
-from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
+from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config
 from llmtuner.webui.utils import can_quantize
 
 if TYPE_CHECKING:
@@ -15,27 +15,32 @@ def create_top() -> Dict[str, "Component"]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
 
     with gr.Row():
-        lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
+        lang = gr.Dropdown(choices=["en", "zh"], scale=1)
         model_name = gr.Dropdown(choices=available_models, scale=3)
         model_path = gr.Textbox(scale=3)
 
     with gr.Row():
-        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
+        finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
         checkpoints = gr.Dropdown(multiselect=True, scale=5)
         refresh_btn = gr.Button(scale=1)
 
     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
         with gr.Row():
-            quantization_bit = gr.Dropdown([8, 4], scale=1)
-            template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
-            source_prefix = gr.Textbox(scale=2)
+            quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
+            template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
+            system_prompt = gr.Textbox(scale=2)
 
+    lang.change(save_config, [lang, model_name, model_path])
+
     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
     ).then(
         get_model_path, [model_name], [model_path]
+    ).then(
+        get_template, [model_name], [template]
     ) # do not save config since the below line will save
-    model_path.change(save_config, [model_name, model_path])
+
+    model_path.change(save_config, [lang, model_name, model_path])
 
     finetuning_type.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
@@ -43,7 +48,9 @@ def create_top() -> Dict[str, "Component"]:
         can_quantize, [finetuning_type], [quantization_bit]
     )
 
-    refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
+    refresh_btn.click(
+        list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
+    )
 
     return dict(
         lang=lang,
@@ -55,5 +62,5 @@ def create_top() -> Dict[str, "Component"]:
         advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
         template=template,
-        source_prefix=source_prefix
+        system_prompt=system_prompt
     )
@@ -3,7 +3,8 @@ from transformers.trainer_utils import SchedulerType
 
 import gradio as gr
 
-from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
+from llmtuner.extras.constants import STAGES
+from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
 from llmtuner.webui.components.data import create_preview_box
 from llmtuner.webui.utils import can_preview, get_preview, gen_plot
 
@@ -12,17 +13,23 @@ if TYPE_CHECKING:
     from llmtuner.webui.runner import Runner
 
 
-def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
+def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     with gr.Row():
+        training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2)
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
-        preview_btn = gr.Button(interactive=False, scale=1)
+        data_preview_btn = gr.Button(interactive=False, scale=1)
 
     preview_box, preview_count, preview_samples, close_btn = create_preview_box()
 
     dataset_dir.change(list_dataset, [dataset_dir], [dataset])
-    dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
-    preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
+    dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
+    data_preview_btn.click(
+        get_preview,
+        [dataset_dir, dataset],
+        [preview_count, preview_samples, preview_box],
+        queue=False
+    )
 
     with gr.Row():
         max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
@@ -35,10 +42,10 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
             batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
             gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
             lr_scheduler_type = gr.Dropdown(
-                value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
+                choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
             )
             max_grad_norm = gr.Textbox(value="1.0")
-            dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
+            val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
 
     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
         with gr.Row():
@@ -46,20 +53,40 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
             save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
             warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
             compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
+            padding_side = gr.Radio(choices=["left", "right"], value="left")
 
     with gr.Accordion(label="LoRA config", open=False) as lora_tab:
         with gr.Row():
             lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
-            lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
+            lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
             lora_target = gr.Textbox(scale=2)
+            resume_lora_training = gr.Checkbox(value=True, scale=1)
+
+    with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
+        with gr.Row():
+            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
+            reward_model = gr.Dropdown(scale=2)
+            refresh_btn = gr.Button(scale=1)
+
+    refresh_btn.click(
+        list_checkpoint,
+        [top_elems["model_name"], top_elems["finetuning_type"]],
+        [reward_model],
+        queue=False
+    )
+
     with gr.Row():
+        cmd_preview_btn = gr.Button()
         start_btn = gr.Button()
         stop_btn = gr.Button()
 
     with gr.Row():
         with gr.Column(scale=3):
-            output_dir = gr.Textbox()
+            with gr.Row():
+                output_dir = gr.Textbox()
+
+            with gr.Row():
+                process_bar = gr.Slider(visible=False, interactive=False)
 
             with gr.Box():
                 output_box = gr.Markdown()
@@ -67,49 +94,59 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
         with gr.Column(scale=1):
             loss_viewer = gr.Plot()
 
-    start_btn.click(
-        runner.run_train,
-        [
-            top_elems["lang"],
-            top_elems["model_name"],
-            top_elems["checkpoints"],
-            top_elems["finetuning_type"],
-            top_elems["quantization_bit"],
-            top_elems["template"],
-            top_elems["source_prefix"],
-            dataset_dir,
-            dataset,
-            max_source_length,
-            max_target_length,
-            learning_rate,
-            num_train_epochs,
-            max_samples,
-            batch_size,
-            gradient_accumulation_steps,
-            lr_scheduler_type,
-            max_grad_norm,
-            dev_ratio,
-            logging_steps,
-            save_steps,
-            warmup_steps,
-            compute_type,
-            lora_rank,
-            lora_dropout,
-            lora_target,
-            output_dir
-        ],
-        [output_box]
-    )
+    input_components = [
+        top_elems["lang"],
+        top_elems["model_name"],
+        top_elems["checkpoints"],
+        top_elems["finetuning_type"],
+        top_elems["quantization_bit"],
+        top_elems["template"],
+        top_elems["system_prompt"],
+        training_stage,
+        dataset_dir,
+        dataset,
+        max_source_length,
+        max_target_length,
+        learning_rate,
+        num_train_epochs,
+        max_samples,
+        batch_size,
+        gradient_accumulation_steps,
+        lr_scheduler_type,
+        max_grad_norm,
+        val_size,
+        logging_steps,
+        save_steps,
+        warmup_steps,
+        compute_type,
+        padding_side,
+        lora_rank,
+        lora_dropout,
+        lora_target,
+        resume_lora_training,
+        dpo_beta,
+        reward_model,
+        output_dir
+    ]
+
+    output_components = [
+        output_box,
+        process_bar
+    ]
+
+    cmd_preview_btn.click(runner.preview_train, input_components, output_components)
+    start_btn.click(runner.run_train, input_components, output_components)
     stop_btn.click(runner.set_abort, queue=False)
 
-    output_box.change(
+    process_bar.change(
         gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
     )
 
     return dict(
+        training_stage=training_stage,
         dataset_dir=dataset_dir,
         dataset=dataset,
-        preview_btn=preview_btn,
+        data_preview_btn=data_preview_btn,
        preview_count=preview_count,
        preview_samples=preview_samples,
        close_btn=close_btn,
@@ -122,16 +159,23 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
         gradient_accumulation_steps=gradient_accumulation_steps,
         lr_scheduler_type=lr_scheduler_type,
         max_grad_norm=max_grad_norm,
-        dev_ratio=dev_ratio,
+        val_size=val_size,
         advanced_tab=advanced_tab,
         logging_steps=logging_steps,
         save_steps=save_steps,
         warmup_steps=warmup_steps,
         compute_type=compute_type,
+        padding_side=padding_side,
         lora_tab=lora_tab,
         lora_rank=lora_rank,
         lora_dropout=lora_dropout,
         lora_target=lora_target,
+        resume_lora_training=resume_lora_training,
+        rlhf_tab=rlhf_tab,
+        dpo_beta=dpo_beta,
+        reward_model=reward_model,
+        refresh_btn=refresh_btn,
+        cmd_preview_btn=cmd_preview_btn,
         start_btn=start_btn,
         stop_btn=stop_btn,
         output_dir=output_dir,
@@ -3,11 +3,13 @@ from transformers.utils.versions import require_version
 
 from llmtuner.webui.components import (
     create_top,
-    create_sft_tab,
+    create_train_tab,
     create_eval_tab,
     create_infer_tab,
-    create_export_tab
+    create_export_tab,
+    create_chat_box
 )
+from llmtuner.webui.chat import WebChatModel
 from llmtuner.webui.css import CSS
 from llmtuner.webui.manager import Manager
 from llmtuner.webui.runner import Runner
@@ -22,8 +24,8 @@ def create_ui() -> gr.Blocks:
     with gr.Blocks(title="Web Tuner", css=CSS) as demo:
         top_elems = create_top()
 
-        with gr.Tab("SFT"):
-            sft_elems = create_sft_tab(top_elems, runner)
+        with gr.Tab("Train"):
+            train_elems = create_train_tab(top_elems, runner)
 
         with gr.Tab("Evaluate"):
             eval_elems = create_eval_tab(top_elems, runner)
@@ -34,7 +36,7 @@ def create_ui() -> gr.Blocks:
         with gr.Tab("Export"):
             export_elems = create_export_tab(top_elems)
 
-        elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
+        elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
         manager = Manager(elem_list)
 
         demo.load(
@@ -47,11 +49,29 @@ def create_ui() -> gr.Blocks:
             manager.gen_label,
             [top_elems["lang"]],
             [elem for elems in elem_list for elem in elems.values()],
+            queue=False
         )
 
     return demo
 
 
+def create_web_demo() -> gr.Blocks:
+    chat_model = WebChatModel(lazy_init=False)
+
+    with gr.Blocks(title="Web Demo", css=CSS) as demo:
+        lang = gr.Dropdown(choices=["en", "zh"])
+
+        _, _, _, chat_elems = create_chat_box(chat_model, visible=True)
+
+        manager = Manager([{"lang": lang}, chat_elems])
+
+        demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
+
+        lang.select(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False)
+
+    return demo
+
+
 if __name__ == "__main__":
     demo = create_ui()
     demo.queue()
@@ -77,7 +77,7 @@ LOCALES = {
             "info": "构建提示词时使用的模板"
         }
     },
-    "source_prefix": {
+    "system_prompt": {
         "en": {
             "label": "System prompt (optional)",
             "info": "A sequence used as the default system prompt."
@@ -87,6 +87,16 @@ LOCALES = {
             "info": "默认使用的系统提示词"
         }
     },
+    "training_stage": {
+        "en": {
+            "label": "Stage",
+            "info": "The stage to perform in training."
+        },
+        "zh": {
+            "label": "训练阶段",
+            "info": "目前采用的训练方式。"
+        }
+    },
     "dataset_dir": {
         "en": {
             "label": "Data dir",
@@ -105,12 +115,12 @@ LOCALES = {
             "label": "数据集"
         }
     },
-    "preview_btn": {
+    "data_preview_btn": {
         "en": {
-            "value": "Preview"
+            "value": "Preview dataset"
         },
         "zh": {
-            "value": "预览"
+            "value": "预览数据集"
         }
     },
     "preview_count": {
@@ -227,9 +237,9 @@ LOCALES = {
             "info": "用于梯度裁剪的范数。"
         }
     },
-    "dev_ratio": {
+    "val_size": {
         "en": {
-            "label": "Dev ratio",
+            "label": "Val size",
             "info": "Proportion of data in the dev set."
         },
         "zh": {
@@ -277,6 +287,16 @@ LOCALES = {
             "info": "是否启用 FP16 或 BF16 混合精度训练。"
         }
     },
+    "padding_side": {
+        "en": {
+            "label": "Padding side",
+            "info": "The side on which the model should have padding applied."
+        },
+        "zh": {
+            "label": "填充位置",
+            "info": "使用左填充或右填充。"
+        }
+    },
     "lora_tab": {
         "en": {
             "label": "LoRA configurations"
@@ -315,6 +335,52 @@ LOCALES = {
             "info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
         }
     },
+    "resume_lora_training": {
+        "en": {
+            "label": "Resume LoRA training",
+            "info": "Whether to resume training from the last LoRA weights or create new lora weights."
+        },
+        "zh": {
+            "label": "继续上次的训练",
+            "info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
+        }
+    },
+    "rlhf_tab": {
+        "en": {
+            "label": "RLHF configurations"
+        },
+        "zh": {
+            "label": "RLHF 参数设置"
+        }
+    },
+    "dpo_beta": {
+        "en": {
+            "label": "DPO beta",
+            "info": "Value of the beta parameter in the DPO loss."
+        },
+        "zh": {
+            "label": "DPO beta 参数",
+            "info": "DPO 损失函数中 beta 超参数大小。"
+        }
+    },
+    "reward_model": {
+        "en": {
+            "label": "Reward model",
+            "info": "Checkpoint of the reward model for PPO training."
+        },
+        "zh": {
+            "label": "奖励模型",
+            "info": "PPO 训练中奖励模型的断点路径。"
+        }
+    },
+    "cmd_preview_btn": {
+        "en": {
+            "value": "Preview command"
+        },
+        "zh": {
+            "value": "预览命令"
+        }
+    },
     "start_btn": {
         "en": {
             "value": "Start"
@@ -389,7 +455,7 @@ LOCALES = {
             "value": "模型未加载，请先加载模型。"
         }
     },
-    "prefix": {
+    "system": {
         "en": {
             "placeholder": "System prompt (optional)"
         },
@@ -513,6 +579,10 @@ ALERTS = {
         "en": "Please provide export dir.",
         "zh": "请填写导出目录"
     },
+    "err_failed": {
+        "en": "Failed.",
+        "zh": "训练出错。"
+    },
     "info_aborting": {
         "en": "Aborted, wait for terminating...",
         "zh": "训练中断，正在等待线程结束……"
@@ -12,12 +12,18 @@ class Manager:
     def __init__(self, elem_list: List[Dict[str, Component]]):
         self.elem_list = elem_list
 
-    def gen_refresh(self) -> Dict[str, Any]:
+    def gen_refresh(self, lang: str) -> Dict[str, Any]:
         refresh_dict = {
             "dataset": {"choices": list_dataset()["choices"]},
             "output_dir": {"value": get_time()}
         }
 
         user_config = load_config()
+        if lang:
+            refresh_dict["lang"] = {"value": lang}
+        else:
+            refresh_dict["lang"] = {"value": user_config["lang"] if user_config["lang"] else "en"}
+
         if user_config["last_model"]:
             refresh_dict["model_name"] = {"value": user_config["last_model"]}
             refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
@@ -26,10 +32,12 @@ class Manager:
 
     def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
         update_dict = {}
-        refresh_dict = self.gen_refresh()
+        refresh_dict = self.gen_refresh(lang)
 
         for elems in self.elem_list:
             for name, component in elems.items():
-                update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
+                update_dict[component] = gr.update(
+                    **LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {})
+                )
 
         return update_dict
@@ -1,18 +1,20 @@
+import gradio as gr
 import logging
 import os
 import threading
 import time
 import transformers
-from typing import Generator, List, Optional, Tuple
+from transformers.trainer import TRAINING_ARGS_NAME
+from typing import Any, Dict, Generator, List, Tuple
 
 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
-from llmtuner.tuner import get_train_args, run_sft
+from llmtuner.tuner import run_exp
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS
-from llmtuner.webui.utils import format_info, get_eval_results
+from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
 
 
 class Runner:
@@ -20,49 +22,46 @@ class Runner:
     def __init__(self):
         self.aborted = False
         self.running = False
+        self.logger_handler = LoggerHandler()
+        self.logger_handler.setLevel(logging.INFO)
+        logging.root.addHandler(self.logger_handler)
+        transformers.logging.add_handler(self.logger_handler)
 
     def set_abort(self):
         self.aborted = True
         self.running = False
 
-    def initialize(
+    def _initialize(
         self, lang: str, model_name: str, dataset: List[str]
-    ) -> Tuple[str, str, LoggerHandler, LogCallback]:
+    ) -> str:
         if self.running:
-            return None, ALERTS["err_conflict"][lang], None, None
+            return ALERTS["err_conflict"][lang]
 
         if not model_name:
-            return None, ALERTS["err_no_model"][lang], None, None
+            return ALERTS["err_no_model"][lang]
 
-        model_name_or_path = get_model_path(model_name)
-        if not model_name_or_path:
-            return None, ALERTS["err_no_path"][lang], None, None
+        if not get_model_path(model_name):
+            return ALERTS["err_no_path"][lang]
 
         if len(dataset) == 0:
-            return None, ALERTS["err_no_dataset"][lang], None, None
+            return ALERTS["err_no_dataset"][lang]
 
         self.aborted = False
-        self.running = True
-
-        logger_handler = LoggerHandler()
-        logger_handler.setLevel(logging.INFO)
-        logging.root.addHandler(logger_handler)
-        transformers.logging.add_handler(logger_handler)
-        trainer_callback = LogCallback(self)
-
-        return model_name_or_path, "", logger_handler, trainer_callback
+        self.logger_handler.reset()
+        self.trainer_callback = LogCallback(self)
+        return ""
 
-    def finalize(
-        self, lang: str, finish_info: Optional[str] = None
+    def _finalize(
+        self, lang: str, finish_info: str
     ) -> str:
         self.running = False
         torch_gc()
         if self.aborted:
             return ALERTS["info_aborted"][lang]
         else:
-            return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
+            return finish_info
 
-    def run_train(
+    def _parse_train_args(
         self,
         lang: str,
         model_name: str,
@@ -70,7 +69,8 @@ class Runner:
         finetuning_type: str,
         quantization_bit: str,
         template: str,
-        source_prefix: str,
+        system_prompt: str,
+        training_stage: str,
         dataset_dir: str,
         dataset: List[str],
         max_source_length: int,
@@ -82,37 +82,39 @@ class Runner:
         gradient_accumulation_steps: int,
         lr_scheduler_type: str,
         max_grad_norm: str,
-        dev_ratio: float,
+        val_size: float,
         logging_steps: int,
         save_steps: int,
         warmup_steps: int,
         compute_type: str,
+        padding_side: str,
         lora_rank: int,
         lora_dropout: float,
         lora_target: str,
+        resume_lora_training: bool,
+        dpo_beta: float,
+        reward_model: str,
         output_dir: str
-    ) -> Generator[str, None, None]:
-        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
-        if error:
-            yield error
-            return
-
+    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
             checkpoint_dir = ",".join(
-                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+                [os.path.join(get_save_dir(model_name), finetuning_type, ckpt) for ckpt in checkpoints]
             )
         else:
             checkpoint_dir = None
 
+        output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
+
         args = dict(
-            model_name_or_path=model_name_or_path,
+            stage="sft",
+            model_name_or_path=get_model_path(model_name),
             do_train=True,
             overwrite_cache=True,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
-            source_prefix=source_prefix,
+            system_prompt=system_prompt,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
             max_source_length=max_source_length,
@@ -127,42 +129,40 @@ class Runner:
             logging_steps=logging_steps,
             save_steps=save_steps,
             warmup_steps=warmup_steps,
-            fp16=(compute_type == "fp16"),
-            bf16=(compute_type == "bf16"),
+            padding_side=padding_side,
             lora_rank=lora_rank,
             lora_dropout=lora_dropout,
             lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
-            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
+            resume_lora_training=resume_lora_training,
+            output_dir=output_dir
         )
+        args[compute_type] = True
 
-        if dev_ratio > 1e-6:
-            args["dev_ratio"] = dev_ratio
+        if training_stage == "Reward Modeling":
+            args["stage"] = "rm"
+            args["resume_lora_training"] = False
+        elif training_stage == "PPO":
+            args["stage"] = "ppo"
+            args["resume_lora_training"] = False
+            args["reward_model"] = reward_model
+            args["padding_side"] = "left"
+            val_size = 0
+        elif training_stage == "DPO":
+            args["stage"] = "dpo"
+            args["resume_lora_training"] = False
+            args["dpo_beta"] = dpo_beta
+        elif training_stage == "Pre-Training":
+            args["stage"] = "pt"
+
+        if val_size > 1e-6:
+            args["val_size"] = val_size
             args["evaluation_strategy"] = "steps"
             args["eval_steps"] = save_steps
             args["load_best_model_at_end"] = True
 
-        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
+        return lang, model_name, dataset, output_dir, args
 
-        run_args = dict(
-            model_args=model_args,
-            data_args=data_args,
-            training_args=training_args,
-            finetuning_args=finetuning_args,
-            callbacks=[trainer_callback]
-        )
-        thread = threading.Thread(target=run_sft, kwargs=run_args)
-        thread.start()
-
-        while thread.is_alive():
-            time.sleep(1)
-            if self.aborted:
-                yield ALERTS["info_aborting"][lang]
-            else:
-                yield format_info(logger_handler.log, trainer_callback.tracker)
-
-        yield self.finalize(lang)
-
-    def run_eval(
+    def _parse_eval_args(
         self,
         lang: str,
         model_name: str,
@@ -170,7 +170,7 @@ class Runner:
         finetuning_type: str,
         quantization_bit: str,
         template: str,
-        source_prefix: str,
+        system_prompt: str,
         dataset_dir: str,
         dataset: List[str],
         max_source_length: int,
@@ -178,12 +178,7 @@ class Runner:
         max_samples: str,
         batch_size: int,
         predict: bool
-    ) -> Generator[str, None, None]:
-        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
-        if error:
-            yield error
-            return
-
+    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
             checkpoint_dir = ",".join(
                 [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
@@ -194,15 +189,16 @@ class Runner:
             output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
 
         args = dict(
-            model_name_or_path=model_name_or_path,
+            stage="sft",
+            model_name_or_path=get_model_path(model_name),
             do_eval=True,
             overwrite_cache=True,
             predict_with_generate=True,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
-            source_prefix=source_prefix,
+            system_prompt=system_prompt,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
             max_source_length=max_source_length,
@@ -216,23 +212,72 @@ class Runner:
             args.pop("do_eval", None)
             args["do_predict"] = True
 
-        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
+        return lang, model_name, dataset, output_dir, args
 
-        run_args = dict(
-            model_args=model_args,
-            data_args=data_args,
-            training_args=training_args,
-            finetuning_args=finetuning_args,
-            callbacks=[trainer_callback]
-        )
-        thread = threading.Thread(target=run_sft, kwargs=run_args)
+    def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        lang, model_name, dataset, _, args = self._parse_train_args(*args)
+        error = self._initialize(lang, model_name, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+        else:
+            yield gen_cmd(args), gr.update(visible=False)
+
+    def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        lang, model_name, dataset, _, args = self._parse_eval_args(*args)
+        error = self._initialize(lang, model_name, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+        else:
+            yield gen_cmd(args), gr.update(visible=False)
+
+    def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
+        lang, model_name, dataset, output_dir, args = self._parse_train_args(*args)
+        error = self._initialize(lang, model_name, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+            return
+
+        self.running = True
+        run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
+        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
         thread.start()
 
         while thread.is_alive():
-            time.sleep(1)
+            time.sleep(2)
             if self.aborted:
-                yield ALERTS["info_aborting"][lang]
+                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
             else:
-                yield format_info(logger_handler.log, trainer_callback.tracker)
+                yield self.logger_handler.log, update_process_bar(self.trainer_callback)
 
-        yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
+        if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
+            finish_info = ALERTS["info_finished"][lang]
+        else:
+            finish_info = ALERTS["err_failed"][lang]
+
+        yield self._finalize(lang, finish_info), gr.update(visible=False)
+
+    def run_eval(self, *args) -> Generator[str, None, None]:
+        lang, model_name, dataset, output_dir, args = self._parse_eval_args(*args)
+        error = self._initialize(lang, model_name, dataset)
+        if error:
+            yield error, gr.update(visible=False)
+            return
+
+        self.running = True
+        run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
+        thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
+        thread.start()
+
+        while thread.is_alive():
+            time.sleep(2)
+            if self.aborted:
+                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
|
||||||
|
else:
|
||||||
|
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(output_dir, "all_results.json")):
|
||||||
|
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
|
||||||
|
else:
|
||||||
|
finish_info = ALERTS["err_failed"][lang]
|
||||||
|
|
||||||
|
yield self._finalize(lang, finish_info), gr.update(visible=False)
|
||||||
|
|||||||
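The new run_train and run_eval methods above share one pattern: parse the UI inputs into a flat args dict, start run_exp on a worker thread, and poll that thread from a generator so the interface can stream log text and a progress update on every tick. A minimal, self-contained sketch of that polling pattern follows; poll_job and fake_job are stand-in names for illustration, not functions from this repository.

import threading
import time
from typing import Generator

def fake_job(args: dict) -> None:
    # stand-in for run_exp(args=..., callbacks=[...]); just burns some time
    time.sleep(args.get("duration", 5))

def poll_job(args: dict) -> Generator[str, None, None]:
    # same shape as Runner.run_train: launch the worker, then report while it runs
    thread = threading.Thread(target=fake_job, kwargs=dict(args=args))
    thread.start()
    while thread.is_alive():
        time.sleep(2)
        yield "still running..."
    yield "finished"

for status in poll_job({"duration": 5}):
    print(status)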
@@ -3,22 +3,30 @@ import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Any, Dict, Generator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
 from datetime import datetime

 from llmtuner.extras.ploting import smooth
-from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
+from llmtuner.tuner import export_model
 from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
 from llmtuner.webui.locales import ALERTS

+if TYPE_CHECKING:
+    from llmtuner.extras.callbacks import LogCallback

-def format_info(log: str, tracker: dict) -> str:
-    info = log
-    if "current_steps" in tracker:
-        info += "Running **{:d}/{:d}**: {} < {}\n".format(
-            tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
-        )
-    return info
+def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
+    if not callback.max_steps:
+        return gr.update(visible=False)
+
+    percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
+    label = "Running {:d}/{:d}: {} < {}".format(
+        callback.cur_steps,
+        callback.max_steps,
+        callback.elapsed_time,
+        callback.remaining_time
+    )
+    return gr.update(label=label, value=percentage, visible=True)


 def get_time() -> str:
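update_process_bar replaces the old Markdown-string formatting with a gr.update payload, so the same output slot can either hide the progress widget or refresh its label and value in place. A rough usage sketch is below; the SimpleNamespace stands in for the LogCallback attributes the function reads, and the idea that the payload feeds a slider-style progress component is an assumption about the web UI, not something stated in this diff.

from types import SimpleNamespace
import gradio as gr

# dummy object with the attributes update_process_bar reads from LogCallback
cb = SimpleNamespace(cur_steps=30, max_steps=120,
                     elapsed_time="0:01:00", remaining_time="0:03:00")

percentage = round(100 * cb.cur_steps / cb.max_steps, 0)  # 25.0
label = "Running {:d}/{:d}: {} < {}".format(
    cb.cur_steps, cb.max_steps, cb.elapsed_time, cb.remaining_time
)
update = gr.update(label=label, value=percentage, visible=True)
# `update` is a plain dict-like payload; Gradio applies it to whichever
# component is bound to this output (assumed here to be the progress bar).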
@@ -54,6 +62,18 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
     return gr.update(interactive=True)


+def gen_cmd(args: Dict[str, Any]) -> str:
+    if args.get("do_train", None):
+        args["plot_loss"] = True
+    cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
+    for k, v in args.items():
+        if v is not None and v != "":
+            cmd_lines.append(" --{} {} ".format(k, str(v)))
+    cmd_text = "\\\n".join(cmd_lines)
+    cmd_text = "```bash\n{}\n```".format(cmd_text)
+    return cmd_text
+
+
 def get_eval_results(path: os.PathLike) -> str:
     with open(path, "r", encoding="utf-8") as f:
         result = json.dumps(json.load(f), indent=4)
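gen_cmd is what the new preview buttons display instead of launching a run: it flattens the parsed args dict into a multi-line shell command and, per its last lines above, wraps the result in a fenced bash block for the UI. A quick trace of the same logic with a made-up args dict; the key names and values here are only examples.

args = {"stage": "sft", "do_train": True, "model_name_or_path": "my-base-model", "output_dir": "out"}

if args.get("do_train", None):
    args["plot_loss"] = True
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
for k, v in args.items():
    if v is not None and v != "":
        cmd_lines.append(" --{} {} ".format(k, str(v)))
print("\\\n".join(cmd_lines))
# CUDA_VISIBLE_DEVICES=0 python \
#  --stage sft \
#  --do_train True \
#  --model_name_or_path my-base-model \
#  --output_dir out \
#  --plot_loss True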
@@ -87,8 +107,14 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
     return fig


-def export_model(
-    lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
+def save_model(
+    lang: str,
+    model_name: str,
+    checkpoints: List[str],
+    finetuning_type: str,
+    template: str,
+    max_shard_size: int,
+    save_dir: str
 ) -> Generator[str, None, None]:
     if not model_name:
         yield ALERTS["err_no_model"][lang]
@@ -114,12 +140,11 @@ def export_model(
     args = dict(
         model_name_or_path=model_name_or_path,
         checkpoint_dir=checkpoint_dir,
-        finetuning_type=finetuning_type
+        finetuning_type=finetuning_type,
+        template=template,
+        output_dir=save_dir
     )

     yield ALERTS["info_exporting"][lang]
-    model_args, _, finetuning_args, _ = get_infer_args(args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-    model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
-    tokenizer.save_pretrained(save_dir)
+    export_model(args, max_shard_size="{}GB".format(max_shard_size))
     yield ALERTS["info_exported"][lang]
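The export path above now hands everything to llmtuner's export_model helper instead of loading the model and calling save_pretrained by hand, and the target directory rides along inside the args dict as output_dir. A rough sketch of the equivalent direct call follows; every value below is a placeholder, not taken from the repository.

from llmtuner.tuner import export_model

args = dict(
    model_name_or_path="path/to/base-model",       # placeholder
    checkpoint_dir="path/to/adapter-checkpoint",   # placeholder
    finetuning_type="lora",
    template="default",
    output_dir="path/to/exported-model"            # placeholder
)

# same call save_model() makes above, here with a 10 GB shard limit
export_model(args, max_shard_size="10GB")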
@@ -1,17 +1,8 @@
-from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
+from llmtuner import run_exp


 def main():
-    model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
+    run_exp()

-    if general_args.stage == "pt":
-        run_pt(model_args, data_args, training_args, finetuning_args)
-    elif general_args.stage == "sft":
-        run_sft(model_args, data_args, training_args, finetuning_args)
-    elif general_args.stage == "rm":
-        run_rm(model_args, data_args, training_args, finetuning_args)
-    elif general_args.stage == "ppo":
-        run_ppo(model_args, data_args, training_args, finetuning_args)


 def _mp_fn(index):
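With the entry script reduced to a bare run_exp() call, the same args dict the web UI assembles can also drive training programmatically; Runner.run_train above does exactly that on a background thread. A minimal sketch, assuming run_exp accepts the args dict as shown in the runner diff and that the callbacks argument can be omitted; the values are placeholders.

from llmtuner import run_exp

args = dict(
    stage="sft",
    do_train=True,
    model_name_or_path="path/to/base-model",   # placeholder
    dataset="example_dataset",                 # placeholder
    finetuning_type="lora",
    output_dir="saves/example-run"             # placeholder
)
run_exp(args=args)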
@@ -1,4 +1,4 @@
-from llmtuner.webui.interface import create_ui
+from llmtuner import create_ui


 def main():

@@ -1,33 +1,8 @@
-# coding=utf-8
-# Implements user interface in browser for fine-tuned models.
-# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
-
-import gradio as gr
-from transformers.utils.versions import require_version
-
-from llmtuner.tuner import get_infer_args
-from llmtuner.webui.chat import WebChatModel
-from llmtuner.webui.components.chatbot import create_chat_box
-from llmtuner.webui.manager import Manager
-
-
-require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
+from llmtuner import create_web_demo


 def main():
-    chat_model = WebChatModel(*get_infer_args())
-
-    with gr.Blocks(title="Web Demo") as demo:
-        lang = gr.Dropdown(choices=["en", "zh"], value="en")
-
-        _, _, _, chat_elems = create_chat_box(chat_model, visible=True)
-
-        manager = Manager([{"lang": lang}, chat_elems])
-
-        demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
-
-        lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
-
+    demo = create_web_demo()
     demo.queue()
     demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)